[Bugfix] Check that number of images matches number of <|image|> tokens with mllama (#11939)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
Travis Johnson
2025-01-10 16:26:00 -07:00
committed by GitHub
parent 8a579408f3
commit d45cbe70f5

View File

@@ -123,6 +123,13 @@ def input_processor_for_mllama(
assert is_list_of(image_data, Image.Image)
num_image_tokens = dec_inputs['prompt_token_ids'].count(
MLLAMA_IMAGE_TOKEN_ID)
if num_image_tokens != len(image_data):
raise ValueError(
f"The number of image tokens ({num_image_tokens}) must be"
f" the same as the number of images ({len(image_data)})")
# Since only the last group of consecutive images
# is attended to by the decoded tokens, we only need to
# get the number of tiles for those images.
@@ -1493,6 +1500,8 @@ def convert_sparse_cross_attention_mask_to_dense(
dense_mask[seq_start + start:seq_start + end,
tile_start:tile_start + tile] = 1
tile_start += tile
assert ts != -1
assert td != 0
tile_range_for_decode.append((ts, ts + td))
seq_start += length