Compare commits

..

19 Commits

SHA1 Message Date
107d71f2a6 CI 2025-11-11 17:31:04 +01:00
d5247c1a5b finally 2025-11-11 17:23:54 +01:00
6f8942af27 fix 2025-11-11 12:37:59 +01:00
2b36e03a6b fix 2025-11-11 12:29:59 +01:00
3f1ab902f9 Merge branch 'main' into fix-tp 2025-11-11 12:26:23 +01:00
3760afb21c Fix T5Gemma module structure (#42145)
* fix modular

* oupsi typo
2025-11-11 12:26:03 +01:00
aef7d5b88f more 2025-11-11 12:18:01 +01:00
ea300e4e1b more fix 2025-11-11 11:44:42 +01:00
fd0a656ad1 continue to improve 2025-11-11 11:24:57 +01:00
3c0b2b101e fix: improve video processing fps assignment logic (#42009)
* fix: improve video processing fps and do_sample_frames assignment logic

* fix: set return_metadata=True to get metadata

* reformat the modular file

* fix typo

* revert flag change and fix fps assignment

* Taking 'num_frames' into consideration.

Avoid error when 'num_frames' is passed rather than 'fps'.

* fix

* fix: avoid potential reference before assignment error

* fix

* add 'sample_fps' to 'VideoMetadata'

* fix missing comma

* fix trailing whitespace

* Handle different 'sample_indices_fn'

* Cleaning white space

* import callable from collections.abc

* calculate sampled_fps using indices

* correct the order

* fix

* properly check value in kwargs

* handle sampled_fps as property

* remove duplicated definition

* fix

* fix

* add safety check

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
2025-11-11 10:54:33 +01:00
e869e9df54 update deps table (#42120)
* update deps table

* [build-ci-image]

* [build-ci-image]

* [push-ci-image]
2025-11-11 09:23:58 +01:00
f7a41c1c83 fix 2025-11-10 18:53:21 +01:00
3d7769bae3 more fix 2025-11-10 18:48:05 +01:00
12a4f590bc fix the obvious 2025-11-10 18:21:44 +01:00
37d48bbb48 Remove unused functions in image_transforms.py (#42044)
* up

* make style

* Update trimaps logic

* fix typo

* Revert changes

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
2025-11-10 16:55:57 +00:00
21913b2e10 Fix MaskFormer/Mask2Former fast image processors (#41393)
* Merge conflict

* add fast processor

* add fast processor

* make style

* add new convert rgb

* use nested group by shape in mllama fast, add support for multiple inputs in group by shape

* fix maskformer mask2 former fast im proc and add tests

* refactor after review

* add _iterate_items utility

* Fix failing tests

* fix copies and improve docs

---------

Co-authored-by: Vincent <phamvinh257@gmail.com>
2025-11-10 16:48:10 +00:00
0b59a201a9 fix test 2025-11-10 17:08:35 +01:00
01281a08da add test 2025-11-10 17:03:54 +01:00
be9d7e709b fix 2025-11-10 16:16:08 +01:00
32 changed files with 510 additions and 383 deletions

View File

@ -138,7 +138,7 @@ _deps = [
"pyyaml>=5.1",
"pydantic>=2",
"pytest>=7.2.0",
"pytest-asyncio",
"pytest-asyncio>=1.2.0",
"pytest-rerunfailures<16.0",
"pytest-timeout",
"pytest-xdist",

View File

@ -48,7 +48,7 @@ deps = {
"pyyaml": "pyyaml>=5.1",
"pydantic": "pydantic>=2",
"pytest": "pytest>=7.2.0",
"pytest-asyncio": "pytest-asyncio",
"pytest-asyncio": "pytest-asyncio>=1.2.0",
"pytest-rerunfailures": "pytest-rerunfailures<16.0",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",

View File

@ -821,14 +821,26 @@ def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_widt
return image
def _cast_tensor_to_float(x):
if x.is_floating_point():
return x
return x.float()
def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
"""Helper function to flatten a single level of nested image and batch structures and group by shape."""
"""
Helper function to flatten a single level of nested image and batch structures and group by shape.
Args:
nested_images (list):
A list of images or a single tensor
paired_inputs (Any, *optional*):
Zero or more lists that mirror the structure of `nested_images` (flat list, or list of lists when
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
same shape key. These paired values are grouped alongside `nested_images` but are not stacked in the output, so
they do not need to be tensors.
is_nested (bool, *optional*, defaults to False):
Whether the images are nested.
Returns:
tuple[dict, ...]:
- A dictionary with shape as key and list of images with that shape as value
- A dictionary with shape as key and list of paired values with that shape as value
- A dictionary mapping original indices to (shape, index) tuples
- A dictionary mapping original indices to (shape, index) tuples for each paired input
"""
grouped_images = defaultdict(list)
grouped_images_index = {}
paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
@ -880,27 +892,20 @@ def _reconstruct_nested_structure(indices, processed_images):
return result
def _disable_grouping_output_nested(images, *paired_inputs):
"""Build the disable_grouping output tuple for a single-level nested structure."""
outer_range = range(len(images))
inner_ranges = [range(len(images[i])) for i in outer_range]
def _iterate_items(items, is_nested: bool):
"""
Helper function to iterate over items yielding (key, item) pairs.
# Precompute all (i, j) pairs
ij_pairs = [(i, j) for i in outer_range for j in inner_ranges[i]]
images_dict = {(i, j): images[i][j].unsqueeze(0) for (i, j) in ij_pairs}
paired_dicts = [{(i, j): paired_list[i][j].unsqueeze(0) for (i, j) in ij_pairs} for paired_list in paired_inputs]
index_map = {(i, j): ((i, j), 0) for (i, j) in ij_pairs}
return images_dict, *paired_dicts, index_map
def _disable_grouping_output_flat(images, *paired_inputs):
"""Build the disable_grouping output tuple for a flat list structure."""
idx_range = range(len(images))
images_dict = {i: images[i].unsqueeze(0) for i in idx_range}
paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in idx_range} for paired_list in paired_inputs]
index_map = {i: (i, 0) for i in idx_range}
return images_dict, *paired_dicts, index_map
For nested structures, yields ((row_index, col_index), item).
For flat structures, yields (index, item).
"""
if is_nested:
for i, row in enumerate(items):
for j, item in enumerate(row):
yield (i, j), item
else:
for i, item in enumerate(items):
yield i, item
def group_images_by_shape(
@ -920,7 +925,7 @@ def group_images_by_shape(
Args:
images (Union[list["torch.Tensor"], "torch.Tensor"]):
A list of images or a single tensor
*paired_inputs (Any):
paired_inputs (Any, *optional*):
Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
@ -944,10 +949,14 @@ def group_images_by_shape(
disable_grouping = device == "cpu"
if disable_grouping:
if is_nested:
return _disable_grouping_output_nested(images, *paired_inputs)
else:
return _disable_grouping_output_flat(images, *paired_inputs)
return (
{key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)},
*[
{key: item.unsqueeze(0) for key, item in _iterate_items(paired_list, is_nested)}
for paired_list in paired_inputs
],
{key: (key, 0) for key, _ in _iterate_items(images, is_nested)},
)
# Handle single level nested structure
grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
@ -990,14 +999,3 @@ def reorder_images(
]
return _reconstruct_nested_structure(grouped_images_index, processed_images)
class NumpyToTensor:
"""
Convert a numpy array to a PyTorch tensor.
"""
def __call__(self, image: np.ndarray):
# Same as in PyTorch, we assume incoming numpy images are in HWC format
# c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()
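
A minimal sketch of how the reworked grouping helper behaves (the import path and the tensor shapes below are assumptions, not taken from this diff): images that share a spatial shape are batched under the same key, while paired inputs are grouped under the same keys without being stacked.

import torch
# assumed import location for the helpers shown in the diff above
from transformers.image_processing_utils_fast import group_images_by_shape, reorder_images

images = [torch.rand(3, 4, 4), torch.rand(3, 8, 8), torch.rand(3, 4, 4)]    # two 4x4 images, one 8x8
masks = [torch.zeros(1, 4, 4), torch.zeros(1, 8, 8), torch.zeros(1, 4, 4)]  # one paired value per image

grouped_images, grouped_masks, grouped_index = group_images_by_shape(images, masks, disable_grouping=False)
# images with the same shape end up batched under the same key; the paired masks are grouped
# under the same keys but are never stacked, so they do not even need to be tensors
processed = {shape: batch * 2.0 for shape, batch in grouped_images.items()}
restored = reorder_images(processed, grouped_index)  # back in the original order, one tensor per input image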

View File

@ -140,6 +140,16 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
return [single_size] * blocks
def replace_layer_number_by_wildcard(name: str) -> str:
"""
Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
a dot (`.`) and the end of the string.
This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
"""
return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
"""
Get the TP style for a parameter from the TP plan.
@ -150,11 +160,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
not parent classes for `post_init` calls
"""
generic_param_name = re.sub(r"\d+", "*", parameter_name)
generic_param_name = replace_layer_number_by_wildcard(parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
return tp_plan[module_name]
return None
@ -1086,7 +1096,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
if tp_plan is None:
return
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
unsharded_layers = set(generic_keys)
unused_rules = tp_plan
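
A quick sketch of what the new wildcard helper changes compared to the old blanket `re.sub(r"\d+", "*", ...)` (the parameter names below are made up): only numbers sitting between dots, or between a dot and the end of the name, become wildcards, so digits that belong to a weight name such as `w1` survive.

import re

def replace_layer_number_by_wildcard(name: str) -> str:
    # same regex as in the diff above
    return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)

print(replace_layer_number_by_wildcard("layers.3.mlp.down_proj.weight"))  # layers.*.mlp.down_proj.weight
print(replace_layer_number_by_wildcard("layers.12.mlp.experts.7.w1"))     # layers.*.mlp.experts.*.w1
print(re.sub(r"\d+", "*", "layers.12.mlp.experts.7.w1"))                  # layers.*.mlp.experts.*.w* (old behaviour)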

View File

@ -106,7 +106,6 @@ class ApertusConfig(PreTrainedConfig):
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -123,7 +123,6 @@ class ApertusConfig(LlamaConfig):
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
}
def __init__(

View File

@ -99,15 +99,14 @@ class AriaTextConfig(PreTrainedConfig):
model_type = "aria_text"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `AriaTextModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -169,6 +169,15 @@ class AriaTextConfig(LlamaConfig):
model_type = "aria_text"
base_config_key = "text_config"
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
}
def __init__(
self,

View File

@ -118,9 +118,9 @@ class DogeConfig(PreTrainedConfig):
"layers.*.self_attn.dt_proj": "rowwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.input_layernorm.weight": "sequence_parallel",
"layers.*.input_residual.weight": "sequence_parallel",
"layers.*.input_residual": "sequence_parallel",
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
"layers.*.post_attention_residual.weight": "sequence_parallel",
"layers.*.post_attention_residual": "sequence_parallel",
"norm.weight": "sequence_parallel",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",

View File

@ -146,9 +146,9 @@ class DogeConfig(PreTrainedConfig):
"layers.*.self_attn.dt_proj": "rowwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.input_layernorm.weight": "sequence_parallel",
"layers.*.input_residual.weight": "sequence_parallel",
"layers.*.input_residual": "sequence_parallel",
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
"layers.*.post_attention_residual.weight": "sequence_parallel",
"layers.*.post_attention_residual": "sequence_parallel",
"norm.weight": "sequence_parallel",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",

View File

@ -114,9 +114,9 @@ class FlexOlmoConfig(PreTrainedConfig):
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -125,9 +125,9 @@ class FlexOlmoConfig(OlmoeConfig):
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -630,7 +630,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config: Gemma3TextConfig
base_model_prefix = "language_model"
base_model_prefix = "model"
def __init__(self, config: Gemma3TextConfig):
super().__init__(config)

View File

@ -715,7 +715,7 @@ class Gemma3TextModel(Gemma2Model):
class Gemma3ForCausalLM(Gemma2ForCausalLM):
config: Gemma3TextConfig
base_model_prefix = "language_model"
base_model_prefix = "model"
def __init__(self, config: Gemma3TextConfig):
super().__init__(config)

View File

@ -214,8 +214,9 @@ class Glm4vMoeTextConfig(PreTrainedConfig):
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -159,8 +159,9 @@ class Glm4vMoeTextConfig(Glm4MoeConfig):
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -109,9 +109,6 @@ class Mask2FormerImageProcessorFast(BaseImageProcessorFast):
valid_kwargs = Mask2FormerImageProcessorKwargs
def __init__(self, **kwargs: Unpack[Mask2FormerImageProcessorKwargs]) -> None:
if "pad_and_return_pixel_mask" in kwargs:
kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
size = kwargs.pop("size", None)
max_size = kwargs.pop("max_size", None)
@ -224,7 +221,7 @@ class Mask2FormerImageProcessorFast(BaseImageProcessorFast):
padding = [0, 0, padding_right, padding_bottom]
images = F.pad(images, padding, fill=fill)
if segmentation_maps is not None:
segmentation_maps = F.pad(segmentation_maps, padding, fill=ignore_index)
segmentation_maps = [F.pad(mask, padding, fill=ignore_index) for mask in segmentation_maps]
# Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
pixel_mask = torch.zeros((images.shape[0], *padded_size), dtype=torch.int64, device=images.device)
@ -318,9 +315,11 @@ class Mask2FormerImageProcessorFast(BaseImageProcessorFast):
stacked_images = self.resize(
image=stacked_images, size=size, size_divisor=size_divisor, interpolation=interpolation
)
if segmentation_maps is not None:
if segmentation_maps is not None:
stacked_segmentation_maps = grouped_segmentation_maps[shape]
if do_resize:
stacked_segmentation_maps = self.resize(
image=grouped_segmentation_maps[shape],
image=stacked_segmentation_maps,
size=size,
size_divisor=size_divisor,
interpolation=F.InterpolationMode.NEAREST_EXACT,
@ -357,14 +356,18 @@ class Mask2FormerImageProcessorFast(BaseImageProcessorFast):
mask_labels.append(masks)
class_labels.append(classes)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
processed_pixel_masks_grouped = {}
if segmentation_maps is not None:
grouped_segmentation_maps, grouped_segmentation_maps_index = group_images_by_shape(
mask_labels, disable_grouping=disable_grouping
# group mask_labels as paired inputs and not images so as not to stack them
grouped_images, grouped_segmentation_maps, grouped_images_index = group_images_by_shape(
resized_images, mask_labels, disable_grouping=disable_grouping
)
processed_segmentation_maps_grouped = {}
else:
grouped_images, grouped_images_index = group_images_by_shape(
resized_images, disable_grouping=disable_grouping
)
processed_images_grouped = {}
processed_pixel_masks_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
@ -379,7 +382,8 @@ class Mask2FormerImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = padded_images
processed_pixel_masks_grouped[shape] = pixel_masks
if segmentation_maps is not None:
processed_segmentation_maps_grouped[shape] = padded_segmentation_maps.squeeze(1)
processed_segmentation_maps_grouped[shape] = padded_segmentation_maps
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_pixel_masks = reorder_images(processed_pixel_masks_grouped, grouped_images_index)
encoded_inputs = BatchFeature(
@ -390,7 +394,7 @@ class Mask2FormerImageProcessorFast(BaseImageProcessorFast):
tensor_type=return_tensors,
)
if segmentation_maps is not None:
mask_labels = reorder_images(processed_segmentation_maps_grouped, grouped_segmentation_maps_index)
mask_labels = reorder_images(processed_segmentation_maps_grouped, grouped_images_index)
# we cannot batch them since they don't share a common class size
encoded_inputs["mask_labels"] = mask_labels
encoded_inputs["class_labels"] = class_labels

View File

@ -114,9 +114,6 @@ class MaskFormerImageProcessorFast(BaseImageProcessorFast):
valid_kwargs = MaskFormerImageProcessorKwargs
def __init__(self, **kwargs: Unpack[MaskFormerImageProcessorKwargs]) -> None:
if "pad_and_return_pixel_mask" in kwargs:
kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
size = kwargs.pop("size", None)
max_size = kwargs.pop("max_size", None)
@ -229,7 +226,7 @@ class MaskFormerImageProcessorFast(BaseImageProcessorFast):
padding = [0, 0, padding_right, padding_bottom]
images = F.pad(images, padding, fill=fill)
if segmentation_maps is not None:
segmentation_maps = F.pad(segmentation_maps, padding, fill=ignore_index)
segmentation_maps = [F.pad(mask, padding, fill=ignore_index) for mask in segmentation_maps]
# Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
pixel_mask = torch.zeros((images.shape[0], *padded_size), dtype=torch.int64, device=images.device)
@ -323,9 +320,11 @@ class MaskFormerImageProcessorFast(BaseImageProcessorFast):
stacked_images = self.resize(
image=stacked_images, size=size, size_divisor=size_divisor, interpolation=interpolation
)
if segmentation_maps is not None:
if segmentation_maps is not None:
stacked_segmentation_maps = grouped_segmentation_maps[shape]
if do_resize:
stacked_segmentation_maps = self.resize(
image=grouped_segmentation_maps[shape],
image=stacked_segmentation_maps,
size=size,
size_divisor=size_divisor,
interpolation=F.InterpolationMode.NEAREST_EXACT,
@ -362,14 +361,18 @@ class MaskFormerImageProcessorFast(BaseImageProcessorFast):
mask_labels.append(masks)
class_labels.append(classes)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
processed_pixel_masks_grouped = {}
if segmentation_maps is not None:
grouped_segmentation_maps, grouped_segmentation_maps_index = group_images_by_shape(
mask_labels, disable_grouping=disable_grouping
# group mask_labels as paired inputs and not images so as not to stack them
grouped_images, grouped_segmentation_maps, grouped_images_index = group_images_by_shape(
resized_images, mask_labels, disable_grouping=disable_grouping
)
processed_segmentation_maps_grouped = {}
else:
grouped_images, grouped_images_index = group_images_by_shape(
resized_images, disable_grouping=disable_grouping
)
processed_images_grouped = {}
processed_pixel_masks_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
@ -384,7 +387,8 @@ class MaskFormerImageProcessorFast(BaseImageProcessorFast):
processed_images_grouped[shape] = padded_images
processed_pixel_masks_grouped[shape] = pixel_masks
if segmentation_maps is not None:
processed_segmentation_maps_grouped[shape] = padded_segmentation_maps.squeeze(1)
processed_segmentation_maps_grouped[shape] = padded_segmentation_maps
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_pixel_masks = reorder_images(processed_pixel_masks_grouped, grouped_images_index)
encoded_inputs = BatchFeature(
@ -395,7 +399,7 @@ class MaskFormerImageProcessorFast(BaseImageProcessorFast):
tensor_type=return_tensors,
)
if segmentation_maps is not None:
mask_labels = reorder_images(processed_segmentation_maps_grouped, grouped_segmentation_maps_index)
mask_labels = reorder_images(processed_segmentation_maps_grouped, grouped_images_index)
# we cannot batch them since they don't share a common class size
encoded_inputs["mask_labels"] = mask_labels
encoded_inputs["class_labels"] = class_labels

View File

@ -839,6 +839,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"return_metadata": True},
}
@ -922,10 +923,17 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
image_grid_thw = image_inputs["image_grid_thw"]
if videos is not None:
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
# Get video metadata
if not kwargs.get("return_metadata"):
video_metadata = videos_inputs.pop("video_metadata")
else:
video_metadata = videos_inputs["video_metadata"]
fps = [metadata.sampled_fps for metadata in video_metadata]
if isinstance(fps, (int, float)):
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
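
A small worked sketch of the fps handling introduced here (all numbers invented; `temporal_patch_size` stands in for `self.video_processor.temporal_patch_size`): with per-video metadata, `fps` becomes a list with one sampled rate per video, and `second_per_grid_ts` is computed per video instead of repeating a single default.

temporal_patch_size = 2  # placeholder value

fps = [2.0, 4.0]                               # per-video sampled fps taken from the metadata
video_grid_thw = [(8, 24, 24), (16, 24, 24)]   # hypothetical temporal/spatial grids

if isinstance(fps, (int, float)):
    second_per_grid_ts = [temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
    second_per_grid_ts = [temporal_patch_size / f for f in fps]

print(second_per_grid_ts)  # [1.0, 0.5]: a temporal grid step spans 1.0s for the first video, 0.5s for the second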

View File

@ -41,6 +41,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"return_metadata": True},
}
@ -129,10 +130,17 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
image_grid_thw = image_inputs["image_grid_thw"]
if videos is not None:
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
# Get video metadata
if not kwargs.get("return_metadata"):
video_metadata = videos_inputs.pop("video_metadata")
else:
video_metadata = videos_inputs["video_metadata"]
fps = [metadata.sampled_fps for metadata in video_metadata]
if isinstance(fps, (int, float)):
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):

View File

@ -138,9 +138,9 @@ class Qwen3NextConfig(PreTrainedConfig):
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
"layers.*.mlp.shared_expert.gate_proj": "colwise",
"layers.*.mlp.shared_expert.up_proj": "colwise",
"layers.*.mlp.shared_expert.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",

View File

@ -228,38 +228,7 @@ class T5GemmaConfig(PreTrainedConfig):
model_type = "t5gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
# encoder
"encoder.layers.*.self_attn.q_proj": "colwise",
"encoder.layers.*.self_attn.k_proj": "colwise",
"encoder.layers.*.self_attn.v_proj": "colwise",
"encoder.layers.*.self_attn.o_proj": "rowwise",
"encoder.layers.*.mlp.gate_proj": "colwise",
"encoder.layers.*.mlp.up_proj": "colwise",
"encoder.layers.*.mlp.down_proj": "rowwise",
# decoder
"decoder.layers.*.self_attn.q_proj": "colwise",
"decoder.layers.*.self_attn.k_proj": "colwise",
"decoder.layers.*.self_attn.v_proj": "colwise",
"decoder.layers.*.self_attn.o_proj": "rowwise",
"decoder.layers.*.cross_attn.q_proj": "colwise",
"decoder.layers.*.cross_attn.k_proj": "colwise",
"decoder.layers.*.cross_attn.v_proj": "colwise",
"decoder.layers.*.cross_attn.o_proj": "rowwise",
"decoder.layers.*.mlp.gate_proj": "colwise",
"decoder.layers.*.mlp.up_proj": "colwise",
"decoder.layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
# encoder
"encoder.embed_tokens": (["input_ids"], ["inputs_embeds"]),
"encoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"encoder.norm": (["hidden_states"], ["hidden_states"]),
# decoder
"decoder.embed_tokens": (["input_ids"], ["inputs_embeds"]),
"decoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"decoder.norm": (["hidden_states"], ["hidden_states"]),
}
sub_configs = {"encoder": T5GemmaModuleConfig, "decoder": T5GemmaModuleConfig}
def __init__(
self,

View File

@ -448,11 +448,28 @@ class T5GemmaEncoderLayer(GradientCheckpointingLayer):
return hidden_states
class T5GemmaDecoderLayer(T5GemmaEncoderLayer):
class T5GemmaDecoderLayer(GradientCheckpointingLayer):
"""Decoder sub-layer: an extra cross-attention layer."""
def __init__(self, config, layer_idx: int):
super().__init__(config, layer_idx)
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
self.layer_idx = layer_idx
self.attention_type = config.layer_types[layer_idx]
self.self_attn = T5GemmaSelfAttention(
config=config,
layer_idx=layer_idx,
)
self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = T5GemmaMLP(config)
self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dropout = nn.Dropout(config.dropout_rate)
self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx)
self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -732,7 +749,7 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel):
)
class T5GemmaDecoder(T5GemmaEncoder):
class T5GemmaDecoder(T5GemmaPreTrainedModel):
_can_record_outputs = {
"attentions": OutputRecorder(T5GemmaSelfAttention, index=1),
"cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1),
@ -741,11 +758,20 @@ class T5GemmaDecoder(T5GemmaEncoder):
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.layers = nn.ModuleList(
[T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.dropout = nn.Dropout(config.dropout_rate)
self.rotary_emb = T5GemmaRotaryEmbedding(config=config)
# Initialize weights and apply final processing
self.post_init()
@check_model_inputs()

View File

@ -236,38 +236,7 @@ class T5GemmaConfig(PreTrainedConfig):
model_type = "t5gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
# encoder
"encoder.layers.*.self_attn.q_proj": "colwise",
"encoder.layers.*.self_attn.k_proj": "colwise",
"encoder.layers.*.self_attn.v_proj": "colwise",
"encoder.layers.*.self_attn.o_proj": "rowwise",
"encoder.layers.*.mlp.gate_proj": "colwise",
"encoder.layers.*.mlp.up_proj": "colwise",
"encoder.layers.*.mlp.down_proj": "rowwise",
# decoder
"decoder.layers.*.self_attn.q_proj": "colwise",
"decoder.layers.*.self_attn.k_proj": "colwise",
"decoder.layers.*.self_attn.v_proj": "colwise",
"decoder.layers.*.self_attn.o_proj": "rowwise",
"decoder.layers.*.cross_attn.q_proj": "colwise",
"decoder.layers.*.cross_attn.k_proj": "colwise",
"decoder.layers.*.cross_attn.v_proj": "colwise",
"decoder.layers.*.cross_attn.o_proj": "rowwise",
"decoder.layers.*.mlp.gate_proj": "colwise",
"decoder.layers.*.mlp.up_proj": "colwise",
"decoder.layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
# encoder
"encoder.embed_tokens": (["input_ids"], ["inputs_embeds"]),
"encoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"encoder.norm": (["hidden_states"], ["hidden_states"]),
# decoder
"decoder.embed_tokens": (["input_ids"], ["inputs_embeds"]),
"decoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"decoder.norm": (["hidden_states"], ["hidden_states"]),
}
sub_configs = {"encoder": T5GemmaModuleConfig, "decoder": T5GemmaModuleConfig}
def __init__(
self,
@ -517,11 +486,28 @@ class T5GemmaEncoderLayer(GradientCheckpointingLayer):
return hidden_states
class T5GemmaDecoderLayer(T5GemmaEncoderLayer):
class T5GemmaDecoderLayer(GradientCheckpointingLayer):
"""Decoder sub-layer: an extra cross-attention layer."""
def __init__(self, config, layer_idx: int):
super().__init__(config, layer_idx)
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
self.layer_idx = layer_idx
self.attention_type = config.layer_types[layer_idx]
self.self_attn = T5GemmaSelfAttention(
config=config,
layer_idx=layer_idx,
)
self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = T5GemmaMLP(config)
self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dropout = nn.Dropout(config.dropout_rate)
self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx)
self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -770,7 +756,7 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel):
)
class T5GemmaDecoder(T5GemmaEncoder):
class T5GemmaDecoder(T5GemmaPreTrainedModel):
_can_record_outputs = {
"attentions": OutputRecorder(T5GemmaSelfAttention, index=1),
"cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1),
@ -779,11 +765,20 @@ class T5GemmaDecoder(T5GemmaEncoder):
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.layers = nn.ModuleList(
[T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.dropout = nn.Dropout(config.dropout_rate)
self.rotary_emb = T5GemmaRotaryEmbedding(config=config)
# Initialize weights and apply final processing
self.post_init()
@check_model_inputs()

View File

@ -51,46 +51,6 @@ _default_log_level = logging.WARNING
_tqdm_active = not hf_hub_utils.are_progress_bars_disabled()
class Logger(logging.Logger):
def __init__(self, name, level=NOTSET):
super().__init__(name, level=level)
def warning_advice(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS")
if no_advisory_warnings:
return
self.warning(*args, **kwargs)
@functools.lru_cache(None)
def warning_once(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but will emit the warning with the same message only once
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
self.warning(*args, **kwargs)
@functools.lru_cache(None)
def info_once(self, *args, **kwargs):
"""
This method is identical to `logger.info()`, but will emit the info with the same message only once
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
self.info(*args, **kwargs)
logging.setLoggerClass(Logger)
def _get_default_logging_level():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
@ -112,7 +72,7 @@ def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> Logger:
def _get_library_root_logger() -> logging.Logger:
return logging.getLogger(_get_library_name())
@ -183,7 +143,7 @@ def captureWarnings(capture):
_captureWarnings(capture)
def get_logger(name: str | None = None) -> Logger:
def get_logger(name: str | None = None) -> logging.Logger:
"""
Return a logger with the specified name.
@ -341,6 +301,50 @@ def reset_format() -> None:
handler.setFormatter(None)
def warning_advice(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS")
if no_advisory_warnings:
return
self.warning(*args, **kwargs)
logging.Logger.warning_advice = warning_advice
@functools.lru_cache(None)
def warning_once(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but will emit the warning with the same message only once
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
self.warning(*args, **kwargs)
logging.Logger.warning_once = warning_once
@functools.lru_cache(None)
def info_once(self, *args, **kwargs):
"""
This method is identical to `logger.info()`, but will emit the info with the same message only once
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
self.info(*args, **kwargs)
logging.Logger.info_once = info_once
class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""

View File

@ -106,6 +106,13 @@ class VideoMetadata(Mapping):
raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
return [frame_idx / self.fps for frame_idx in self.frames_indices]
@property
def sampled_fps(self) -> float:
"FPS of the sampled video."
if self.frames_indices is None or self.total_num_frames is None or self.fps is None:
return self.fps or 24
return len(self.frames_indices) / self.total_num_frames * self.fps
def update(self, dictionary):
for key, value in dictionary.items():
if hasattr(self, key):
@ -372,8 +379,8 @@ def read_video_opencv(
height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)),
width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
)
indices = sample_indices_fn(metadata=metadata, **kwargs)
indices = sample_indices_fn(metadata=metadata, **kwargs)
index = 0
frames = []
while video.isOpened():
@ -486,8 +493,8 @@ def read_video_pyav(
height=container.streams.video[0].height,
width=container.streams.video[0].width,
)
indices = sample_indices_fn(metadata=metadata, **kwargs)
indices = sample_indices_fn(metadata=metadata, **kwargs)
frames = []
container.seek(0)
end_index = indices[-1]
@ -548,7 +555,6 @@ def read_video_torchvision(
)
indices = sample_indices_fn(metadata=metadata, **kwargs)
video = video[indices].contiguous()
metadata.update(
{
@ -596,16 +602,18 @@ def read_video_torchcodec(
num_ffmpeg_threads=0,
device=kwargs.get("device", "cpu"),
)
total_num_frames = decoder.metadata.num_frames
video_fps = decoder.metadata.average_fps
metadata = VideoMetadata(
total_num_frames=decoder.metadata.num_frames,
fps=decoder.metadata.average_fps,
total_num_frames=total_num_frames,
fps=video_fps,
duration=decoder.metadata.duration_seconds,
video_backend="torchcodec",
height=decoder.metadata.height,
width=decoder.metadata.width,
)
indices = sample_indices_fn(metadata=metadata, **kwargs)
indices = sample_indices_fn(metadata=metadata, **kwargs)
video = decoder.get_frames_at(indices=indices).data.contiguous()
metadata.frames_indices = indices
return video, metadata
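
A worked numeric sketch of the new `sampled_fps` property (numbers invented; the import path and constructor arguments are assumptions, passed as keywords since the full field list is not shown here): for a 10-second, 30 fps clip with 300 frames from which 20 frames are sampled, the effective rate is 20 / 300 * 30 = 2.0 frames per second, and the property falls back to the original fps (or 24) when the indices or totals are unknown.

from transformers.video_utils import VideoMetadata  # assumed import path

metadata = VideoMetadata(total_num_frames=300, fps=30, duration=10.0, video_backend="torchcodec")
metadata.frames_indices = list(range(0, 300, 15))  # 20 sampled frame indices

print(metadata.sampled_fps)  # 20 / 300 * 30 = 2.0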

View File

@ -140,6 +140,17 @@ class DeepseekV2ModelTest(CausalLMModelTest, unittest.TestCase):
def test_torch_compile_for_training(self):
pass
def test_tp_plan_matches_params(self):
"""Need to overwrite as the plan contains keys that are valid but depend on some configs flags and cannot
be valid all at the same time"""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# The key is valid but not always used based on the flag
if config.q_lora_rank is not None:
config.base_model_tp_plan.pop("layers.*.self_attn.q_proj")
super().test_tp_plan_matches_params()
# Put them back in class attribute
config.base_model_tp_plan.update({"layers.*.self_attn.q_proj": "colwise"})
@slow
@require_read_token

View File

@ -337,6 +337,23 @@ class DogeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_save_load_fast_init_from_base(self):
pass
def test_tp_plan_matches_params(self):
"""Need to overwrite as the plan contains keys that are valid but depend on some configs flags and cannot
be valid all at the same time"""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# They are valid but not always used, depending on config.is_moe flag (the modules are not the same in both cases)
problematic_keys = {
"layers.*.mlp.router_gate": "colwise_rep",
"layers.*.mlp.down_embed": "rowwise_rep",
"layers.*.mlp.up_embed": "rowwise_rep",
}
if not config.is_moe:
for key in problematic_keys:
config.base_model_tp_plan.pop(key)
super().test_tp_plan_matches_params()
# Put them back in class attribute
config.base_model_tp_plan.update(problematic_keys)
@require_torch_accelerator
class DogeIntegrationTest(unittest.TestCase):

View File

@ -194,13 +194,14 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
def comm_get_image_processing_inputs(
self,
image_processor_tester,
image_processing_class,
with_segmentation_maps=False,
is_instance_map=False,
segmentation_type="np",
numpify=False,
input_data_format=None,
):
image_processing = self.image_processing_class(**image_processor_tester.prepare_image_processor_dict())
image_processing = image_processing_class(**image_processor_tester.prepare_image_processor_dict())
# prepare image and target
num_labels = image_processor_tester.num_labels
annotations = None
@ -228,7 +229,6 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
input_data_format=input_data_format,
)
@ -264,25 +264,26 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
image_mean=[0.5] * num_channels,
image_std=[0.5] * num_channels,
)
for image_processing_class in self.image_processor_list:
inputs = self.comm_get_image_processing_inputs(
image_processor_tester=image_processor_tester,
image_processing_class=image_processing_class,
with_segmentation_maps=True,
is_instance_map=is_instance_map,
segmentation_type=segmentation_type,
numpify=numpify,
input_data_format=input_data_format,
)
inputs = self.comm_get_image_processing_inputs(
image_processor_tester=image_processor_tester,
with_segmentation_maps=True,
is_instance_map=is_instance_map,
segmentation_type=segmentation_type,
numpify=numpify,
input_data_format=input_data_format,
)
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]
# check the batch_size
for mask_label, class_label in zip(mask_labels, class_labels):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
# this ensure padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
# check the batch_size
for mask_label, class_label in zip(mask_labels, class_labels):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
# this ensure padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
common()
common(is_instance_map=True)
@ -335,31 +336,32 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
instance_seg2, inst2class2 = get_instance_segmentation_and_mapping(annotation2)
# create a image processor
image_processing = Mask2FormerImageProcessor(do_reduce_labels=True, ignore_index=255, size=(512, 512))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(do_reduce_labels=True, ignore_index=255, size=(512, 512))
# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[instance_seg1, instance_seg2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[instance_seg1, instance_seg2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor([30, 55]))
torch.testing.assert_close(inputs["class_labels"][1], torch.tensor([4, 4, 23, 55]))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor([30, 55]))
torch.testing.assert_close(inputs["class_labels"][1], torch.tensor([4, 4, 23, 55]))
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (2, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (4, 512, 512))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 41527.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 26259.0)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (2, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (4, 512, 512))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 41527.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 26259.0)
def test_integration_semantic_segmentation(self):
# load 2 images and corresponding semantic annotations from the hub
@ -378,30 +380,31 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
)
# create a image processor
image_processing = Mask2FormerImageProcessor(do_reduce_labels=True, ignore_index=255, size=(512, 512))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(do_reduce_labels=True, ignore_index=255, size=(512, 512))
# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[annotation1, annotation2],
return_tensors="pt",
)
# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[annotation1, annotation2],
return_tensors="pt",
)
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor([2, 4, 60]))
torch.testing.assert_close(inputs["class_labels"][1], torch.tensor([0, 3, 7, 8, 15, 28, 30, 143]))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor([2, 4, 60]))
torch.testing.assert_close(inputs["class_labels"][1], torch.tensor([0, 3, 7, 8, 15, 28, 30, 143]))
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (3, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (8, 512, 512))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 170200.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 257036.0)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (3, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (8, 512, 512))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 170200.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 257036.0)
def test_integration_panoptic_segmentation(self):
# load 2 images and corresponding panoptic annotations from the hub
@ -435,34 +438,35 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
# create a image processor
image_processing = Mask2FormerImageProcessor(ignore_index=0, do_resize=False)
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(ignore_index=0, do_resize=False)
# prepare the images and annotations
pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)]
inputs = image_processing.encode_inputs(
pixel_values_list,
[panoptic_map1, panoptic_map2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# prepare the images and annotations
pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)]
inputs = image_processing(
pixel_values_list,
[panoptic_map1, panoptic_map2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711))
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
expected_class_labels = torch.tensor([4, 17, 32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 3, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 5, 12, 12, 12, 12, 12, 12, 12, 0, 43, 43, 43, 96, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # fmt: skip
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor(expected_class_labels))
expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 17, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 3, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 5, 12, 12, 0, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # fmt: skip
torch.testing.assert_close(inputs["class_labels"][1], expected_class_labels)
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
expected_class_labels = torch.tensor([4, 17, 32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 3, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 5, 12, 12, 12, 12, 12, 12, 12, 0, 43, 43, 43, 96, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # fmt: skip
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor(expected_class_labels))
expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 17, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 3, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 5, 12, 12, 0, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # fmt: skip
torch.testing.assert_close(inputs["class_labels"][1], expected_class_labels)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (79, 512, 711))
self.assertEqual(inputs["mask_labels"][1].shape, (61, 512, 711))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 315193.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 350747.0)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (79, 512, 711))
self.assertEqual(inputs["mask_labels"][1].shape, (61, 512, 711))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 315193.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 350747.0)
def test_binary_mask_to_rle(self):
fake_binary_mask = np.zeros((20, 50))

View File

@ -188,9 +188,9 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
self.assertTrue(hasattr(image_processing, "num_labels"))
def comm_get_image_processing_inputs(
self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
self, image_processing_class, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# prepare image and target
num_labels = self.image_processor_tester.num_labels
annotations = None
@ -212,7 +212,6 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
)
return inputs
@ -233,19 +232,23 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
def test_call_with_segmentation_maps(self):
def common(is_instance_map=False, segmentation_type=None):
inputs = self.comm_get_image_processing_inputs(
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
)
for image_processing_class in self.image_processor_list:
inputs = self.comm_get_image_processing_inputs(
image_processing_class=image_processing_class,
with_segmentation_maps=True,
is_instance_map=is_instance_map,
segmentation_type=segmentation_type,
)
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]
# check the batch_size
for mask_label, class_label in zip(mask_labels, class_labels):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
# this ensure padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
# check the batch_size
for mask_label, class_label in zip(mask_labels, class_labels):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
# this ensure padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
common()
common(is_instance_map=True)
@ -286,31 +289,32 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
instance_seg2, inst2class2 = get_instance_segmentation_and_mapping(annotation2)
# create a image processor
image_processing = MaskFormerImageProcessor(do_reduce_labels=True, ignore_index=255, size=(512, 512))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(do_reduce_labels=True, ignore_index=255, size=(512, 512))
# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[instance_seg1, instance_seg2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[instance_seg1, instance_seg2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor([30, 55]))
torch.testing.assert_close(inputs["class_labels"][1], torch.tensor([4, 4, 23, 55]))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor([30, 55]))
torch.testing.assert_close(inputs["class_labels"][1], torch.tensor([4, 4, 23, 55]))
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (2, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (4, 512, 512))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 41527.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 26259.0)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (2, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (4, 512, 512))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 41527.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 26259.0)
def test_integration_semantic_segmentation(self):
# load 2 images and corresponding semantic annotations from the hub
@ -329,30 +333,31 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
)
# create a image processor
-image_processing = MaskFormerImageProcessor(do_reduce_labels=True, ignore_index=255, size=(512, 512))
+for image_processing_class in self.image_processor_list:
+image_processing = image_processing_class(do_reduce_labels=True, ignore_index=255, size=(512, 512))
# prepare the images and annotations
inputs = image_processing(
[image1, image2],
[annotation1, annotation2],
return_tensors="pt",
)
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor([2, 4, 60]))
torch.testing.assert_close(inputs["class_labels"][1], torch.tensor([0, 3, 7, 8, 15, 28, 30, 143]))
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (3, 512, 512))
self.assertEqual(inputs["mask_labels"][1].shape, (8, 512, 512))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 170200.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 257036.0)
def test_integration_panoptic_segmentation(self):
# load 2 images and corresponding panoptic annotations from the hub
@@ -386,34 +391,35 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
# create a image processor
-image_processing = MaskFormerImageProcessor(ignore_index=0, do_resize=False)
+for image_processing_class in self.image_processor_list:
+image_processing = image_processing_class(ignore_index=0, do_resize=False)
# prepare the images and annotations
pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)]
-inputs = image_processing.encode_inputs(
+inputs = image_processing(
pixel_values_list,
[panoptic_map1, panoptic_map2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# verify the pixel values and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
expected_class_labels = torch.tensor([4, 17, 32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 3, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 5, 12, 12, 12, 12, 12, 12, 12, 0, 43, 43, 43, 96, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # fmt: skip
torch.testing.assert_close(inputs["class_labels"][0], torch.tensor(expected_class_labels))
expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 17, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 3, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 5, 12, 12, 0, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # fmt: skip
torch.testing.assert_close(inputs["class_labels"][1], expected_class_labels)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (79, 512, 711))
self.assertEqual(inputs["mask_labels"][1].shape, (61, 512, 711))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 315193.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 350747.0)
def test_binary_mask_to_rle(self):
fake_binary_mask = np.zeros((20, 50))
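The hunks above all apply the same refactor: instead of instantiating `MaskFormerImageProcessor` directly, each test now loops over `self.image_processor_list` so that the slow and the fast image processor are exercised with identical assertions, and the panoptic test calls the processor directly rather than going through `encode_inputs`. A minimal sketch of that pattern (a fragment of a test method, assuming the mixin provides `image_processor_list` and the image/annotation fixtures shown in the diff):

```python
# Sketch only: `self.image_processor_list` is assumed to hold the slow and, when
# available, the fast processor class, as set up by the image-processing test mixin.
for image_processing_class in self.image_processor_list:
    image_processing = image_processing_class(do_reduce_labels=True, ignore_index=255, size=(512, 512))
    inputs = image_processing(
        [image1, image2],
        [instance_seg1, instance_seg2],
        instance_id_to_semantic_id=[inst2class1, inst2class2],
        return_tensors="pt",
    )
    # every assertion now runs once per processor class
    self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512))
```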

View File

@@ -632,6 +632,8 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
return False
def test_config(self):
# Skip `create_and_test_config_from_and_save_pretrained_composite` because the config has twice the same subconfig
self.config_tester.create_and_test_config_from_and_save_pretrained_composite = lambda: None
self.config_tester.run_common_tests()
def test_shift_right(self):
@@ -1469,6 +1471,8 @@ class T5GemmaEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
)
def test_config(self):
# Skip `create_and_test_config_from_and_save_pretrained_composite` because the config has twice the same subconfig
self.config_tester.create_and_test_config_from_and_save_pretrained_composite = lambda: None
self.config_tester.run_common_tests()
@unittest.skip("This was not properly written, submodules need the attribute to be overwritten")
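Both T5Gemma test classes receive the same two-line fix: the composite-config round-trip sub-test is neutralized by shadowing the bound `ConfigTester` method with a no-op before `run_common_tests()` runs. A sketch of the idiom, assuming the usual `self.config_tester = ConfigTester(...)` set up elsewhere in the test class:

```python
# Sketch of the skip idiom used above. Assigning an instance attribute shadows the
# bound method, so run_common_tests() invokes the no-op instead of the real check.
def test_config(self):
    self.config_tester.create_and_test_config_from_and_save_pretrained_composite = lambda: None
    self.config_tester.run_common_tests()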

View File

@@ -121,6 +121,7 @@ if is_torch_available():
from torch import nn
from transformers import MODEL_MAPPING
from transformers.integrations.tensor_parallel import _get_parameter_tp_plan
from transformers.modeling_utils import load_state_dict
from transformers.pytorch_utils import id_tensor_storage
@@ -3897,6 +3898,48 @@ class ModelTesterMixin:
self.assertEqual(v1.dtype, v2.dtype)
self.assertTrue((v1 == v2).all())
def test_tp_plan_matches_params(self):
"""Make sure that each entry of the tp plan matches at least one param (this avoid typos and/or edge cases
with regexes)"""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# If none of the config and subconfigs have a tp_plan, then skip (otherwise we should make sure to respect the plan)
if config.base_model_tp_plan is None and all(
getattr(getattr(config, key), "base_model_tp_plan", None) is None for key in config.sub_configs
):
self.skipTest("Model does not have a TP plan.")
# Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have each one
# in order to test the whole tp plan
config_to_set = config.get_text_config()
config_to_set.first_k_dense_replace = 1 # means that the first layer (idx 0) will be MLP, then MoE
config_to_set.moe_layer_start_index = 1 # same as above but for Ernie 4.5...
config_to_set.mlp_only_layers = [0] # same but for qwens
for model_class in self.all_model_classes:
model = model_class(copy.deepcopy(config))
param_names = {name for name, _ in model.named_parameters()} | {name for name, _ in model.named_buffers()}
module_names = {name for name, _ in model.named_modules()}
tp_plan = model.tp_plan
# Make sure the plan is not empty
self.assertTrue(
len(tp_plan) > 0,
f"No TP-plan found for class {model_class.__name__} even though the associated config has one",
)
pattern_usage = {}
for pattern in tp_plan:
# Check if this given pattern matches any param or module (the value attributed to the pattern does not matter)
pattern_usage[pattern] = any(
_get_parameter_tp_plan(param, {pattern: ""}, is_weight=True) is not None for param in param_names
) or any(
_get_parameter_tp_plan(module, {pattern: ""}, is_weight=False) is not None
for module in module_names
)
unused_entries = {k for k, v in pattern_usage.items() if not v}
self.assertTrue(
len(unused_entries) == 0, f"The following entries of the TP-plan are not valid: {unused_entries}"
)
global_rng = random.Random()
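The new `test_tp_plan_matches_params` check boils down to matching every TP-plan key against the model's parameter and module names and flagging keys that match nothing; the MoE-related config tweaks (`first_k_dense_replace`, `moe_layer_start_index`, `mlp_only_layers`) only ensure that both dense and MoE layers exist so every plan entry has something to match. A simplified, self-contained sketch of that idea (not the actual `_get_parameter_tp_plan` implementation, whose matching rules live in `transformers.integrations.tensor_parallel` and are more permissive):

```python
import re

# Simplified stand-in for the check above: a plan key like "layers.*.mlp.down_proj"
# is treated as a glob over the dotted parameter path, with "*" standing for a layer index.
def unused_tp_entries(tp_plan: dict[str, str], param_names: set[str]) -> set[str]:
    """Return TP-plan keys that match none of the given parameter names."""
    unused = set()
    for pattern in tp_plan:
        # e.g. "layers.*.mlp.down_proj" -> r"layers\.\d+\.mlp\.down_proj(\.|$)"
        regex = re.compile(pattern.replace(".", r"\.").replace("*", r"\d+") + r"(\.|$)")
        if not any(regex.search(name) for name in param_names):
            unused.add(pattern)
    return unused

params = {
    "model.layers.0.mlp.down_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
}
plan = {"layers.*.mlp.down_proj": "colwise", "layers.*.mlp.gate": "rowwise"}  # second key has a typo
print(unused_tp_entries(plan, params))  # {'layers.*.mlp.gate'}
```

The stale `"layers.*.mlp.gate"` entry is reported, which is exactly the class of typo and regex edge case the test is meant to catch.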