Add sequence classification support for small Gemma 3 text models (#40562)

* add seq class for gemma3 text model

* add Gemma3TextForSequenceClassification to modeling file

* After running `make fixup`

* let's just check

* this is why it was crashing, tests were just failing...

* skip it, tested only for sequence classification

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
Author: Abdelrahman Kaseb
Date: 2025-09-04 12:44:59 +03:00
Committed by: GitHub
Parent: 30a4b8707d
Commit: 25b4a0d8ae

5 changed files with 53 additions and 3 deletions
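For reference, the new head loads like any other task model. A minimal sketch, assuming a text-only Gemma 3 checkpoint; the model id and label count below are illustrative, not part of this diff:

```python
import torch
from transformers import AutoTokenizer, Gemma3TextForSequenceClassification

model_id = "google/gemma-3-1b-it"  # illustrative text-only checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The classification head is freshly initialised; num_labels=3 is arbitrary here.
model = Gemma3TextForSequenceClassification.from_pretrained(model_id, num_labels=3)

inputs = tokenizer("The movie was surprisingly good.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)
print(logits.argmax(dim=-1))
```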

View File

@@ -273,3 +273,8 @@ visualizer("<img>What is shown in this image?")
 [[autodoc]] Gemma3ForSequenceClassification
     - forward
+
+## Gemma3TextForSequenceClassification
+
+[[autodoc]] Gemma3TextForSequenceClassification
+    - forward

View File

@@ -1207,6 +1207,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("gemma", "GemmaForSequenceClassification"),
         ("gemma2", "Gemma2ForSequenceClassification"),
         ("gemma3", "Gemma3ForSequenceClassification"),
+        ("gemma3_text", "Gemma3TextForSequenceClassification"),
         ("glm", "GlmForSequenceClassification"),
         ("glm4", "Glm4ForSequenceClassification"),
         ("gpt-sw3", "GPT2ForSequenceClassification"),

View File

@@ -33,7 +33,7 @@ from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -1301,6 +1301,15 @@ class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
 )
+
+
+class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
+    """
+    Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
+    It uses the generic sequence classification implementation for efficiency and consistency.
+    """
+
+    config: Gemma3TextConfig
 
 __all__ = [
     "Gemma3PreTrainedModel",
     "Gemma3TextModel",
@@ -1308,4 +1317,5 @@ __all__ = [
     "Gemma3ForConditionalGeneration",
     "Gemma3Model",
     "Gemma3ForSequenceClassification",
+    "Gemma3TextForSequenceClassification",
 ]
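For context, `GenericForSequenceClassification` supplies the standard decoder-only classification forward (a linear score head over hidden states, pooled at the last non-padding token, since causal LMs have no [CLS] token), which is why the subclass only has to declare its config. Roughly, the pooling works like this simplified sketch, not the actual implementation:

```python
import torch

def pool_last_token(logits: torch.Tensor, input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Pick the logits of the last non-padding token for each sequence.

    logits: (batch, seq_len, num_labels); input_ids: (batch, seq_len).
    """
    non_pad = input_ids.ne(pad_token_id).int()
    # cumsum is flat after the last real token, so argmax finds its position
    last_token_positions = non_pad.cumsum(dim=-1).argmax(dim=-1)
    batch_indices = torch.arange(logits.shape[0], device=logits.device)
    return logits[batch_indices, last_token_positions]  # (batch, num_labels)
```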

View File

@@ -26,7 +26,7 @@ from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -1170,6 +1170,15 @@ class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
 )
+
+
+class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
+    """
+    Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
+    It uses the generic sequence classification implementation for efficiency and consistency.
+    """
+
+    config: Gemma3TextConfig
 
 __all__ = [
     "Gemma3Config",
     "Gemma3TextConfig",
@@ -1179,4 +1188,5 @@ __all__ = [
     "Gemma3ForConditionalGeneration",
     "Gemma3Model",
     "Gemma3ForSequenceClassification",
+    "Gemma3TextForSequenceClassification",
 ]

View File

@@ -56,6 +56,7 @@ if is_torch_available():
         Gemma3ForSequenceClassification,
         Gemma3Model,
         Gemma3Processor,
+        Gemma3TextForSequenceClassification,
         Gemma3TextModel,
     )
     from transformers.pytorch_utils import is_torch_greater_or_equal
@@ -70,7 +71,9 @@ class Gemma3ModelTester(GemmaModelTester):
 @require_torch
 class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (Gemma3TextModel, Gemma3ForCausalLM) if is_torch_available() else ()
+    all_model_classes = (
+        (Gemma3TextModel, Gemma3ForCausalLM, Gemma3TextForSequenceClassification) if is_torch_available() else ()
+    )
     all_generative_model_classes = (Gemma3ForCausalLM,) if is_torch_available() else ()
     test_headmasking = False
     test_pruning = False
@@ -97,6 +100,12 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     def test_sdpa_padding_matches_padding_free_with_position_ids(self):
         pass
 
+    @unittest.skip(
+        "Gemma3 has no base model prefix which causes issues when loading base model from saved task model checkpoint"
+    )
+    def test_load_with_mismatched_shapes(self):
+        pass
+
     def test_generation_beyond_sliding_window_tiny_model(self):
         """Test generation with a tiny randomly initialised model whose input length is larger than the `sliding_window`.
         The model is configured with both `full_attention` and `sliding_attention` layers to make sure the hybrid cache
@@ -143,6 +152,21 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         EXPECTED_OUTPUT = torch.tensor([[90109, 90109, 90109, 83191, 83191], [246901, 69832, 69832, 69832, 62288]])
         torch.testing.assert_close(generated_sequences, EXPECTED_OUTPUT)
 
+    def test_gemma3_text_sequence_classification_model(self):
+        """Test the text-only sequence classification model."""
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_labels)
+
+        model = Gemma3TextForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, config.num_labels))
+
 
 class Gemma3Vision2TextModelTester:
     def __init__(
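The new test exercises the same path a user would. As a standalone sanity check outside the test harness, something like the following should work; it is a sketch with a tiny random config, all sizes illustrative:

```python
import torch
from transformers import Gemma3TextConfig, Gemma3TextForSequenceClassification

# Tiny random-weight model for a CPU-only smoke test; not a real checkpoint.
config = Gemma3TextConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_labels=3,
    pad_token_id=0,
)
model = Gemma3TextForSequenceClassification(config).eval()

input_ids = torch.randint(1, config.vocab_size, (2, 10))
labels = torch.tensor([0, 2])
with torch.no_grad():
    out = model(input_ids, labels=labels)
print(out.logits.shape)  # torch.Size([2, 3])
print(out.loss)          # cross-entropy for single-label classification
```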