Compare commits


5 Commits

SHA1 Message Date
da767ce0cb Merge branch 'main' into mllama_integration_tests 2025-07-30 12:58:11 +02:00
49f0f18f1e copies 2025-07-30 12:35:37 +02:00
e0bcc4a10f ruff 2025-07-30 12:34:07 +02:00
f722aae48f fix integration tests 2025-07-30 12:31:10 +02:00
f202533977 fix integration tests 2025-07-30 12:21:20 +02:00


@@ -350,13 +350,18 @@ class MllamaVisionEncoder(nn.Module):
                 [What are attention masks?](../glossary#attention-mask)
         """
+        encoder_states = ()
         for encoder_layer in self.layers:
+            encoder_states = encoder_states + (hidden_states,)
             hidden_states = encoder_layer(
                 hidden_state=hidden_states,
                 attention_mask=attention_mask,
             )
-        return BaseModelOutput(last_hidden_state=hidden_states)
+        encoder_states = encoder_states + (hidden_states,)
+        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)


 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
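The added lines make MllamaVisionEncoder.forward record the hidden state entering each layer plus the final layer output, and return them through BaseModelOutput.hidden_states instead of only last_hidden_state. Below is a minimal, self-contained sketch of that accumulation pattern; ToyEncoder, its nn.Linear layers, and the tensor shapes are made up for illustration and are not the real vision encoder.

import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput

class ToyEncoder(nn.Module):
    """Toy stand-in for MllamaVisionEncoder: collects per-layer hidden states."""

    def __init__(self, num_layers: int = 4, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
        encoder_states = ()
        for layer in self.layers:
            # The state *entering* layer i is stored at index i (index 0 is the input/embeddings).
            encoder_states = encoder_states + (hidden_states,)
            hidden_states = layer(hidden_states)
        # The final output is appended last, so len(encoder_states) == num_layers + 1.
        encoder_states = encoder_states + (hidden_states,)
        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)

encoder = ToyEncoder()
out = encoder(torch.randn(2, 3, 8))
print(len(out.hidden_states))  # 5 entries for 4 layers

With num_layers layers the tuple holds num_layers + 1 entries, which is what the MllamaVisionModel hunk below indexes into.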
@@ -1106,7 +1111,7 @@ class MllamaVisionModel(MllamaPreTrainedModel):
         hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim)

         # Collect intermediate layer outputs from encoder output
-        all_intermediate_hidden_states = [output.last_hidden_state for _ in self.intermediate_layers_indices]
+        all_intermediate_hidden_states = [output.hidden_states[i] for i in self.intermediate_layers_indices]
         intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1)

         # Remove padding from intermediate hidden states
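This hunk fixes how those states are consumed: the old comprehension repeated output.last_hidden_state once per index, so every stacked slice held the final layer's output, whereas the new code selects distinct entries of output.hidden_states. Continuing the toy example above (the indices here are hypothetical; the real model takes intermediate_layers_indices from its configuration):

intermediate_layers_indices = [1, 3]  # hypothetical indices for the toy encoder

all_intermediate_hidden_states = [out.hidden_states[i] for i in intermediate_layers_indices]
intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1)
print(intermediate_hidden_states.shape)  # torch.Size([2, 3, 8, 2]): one trailing slice per selected layer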