Gemma3 fixes (#41572)

* Multiple device error fix

* FA2 equivalence fix

* Move the train fwd in cfg test

* Style

* Added comment

* Made the comment more clear
Rémi Ouazan
2025-10-14 18:33:27 +02:00
committed by GitHub
parent 4c8d293599
commit 9e4199ede3
4 changed files with 29 additions and 7 deletions

View File

@@ -19,6 +19,7 @@ import unittest
import pytest
from parameterized import parameterized
from pytest import mark

from transformers import (
    AutoModelForCausalLM,
@@ -33,9 +34,11 @@ from transformers.testing_utils import (
    is_flash_attn_2_available,
    require_deterministic_for_xpu,
    require_flash_attn,
    require_flash_attn_3,
    require_read_token,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_large_accelerator,
    slow,
    torch_device,
@@ -342,6 +345,20 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
            for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
            self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_from_config(self):
        self.flash_attn_from_config(attn_implementation="flash_attention_2", test_fwd_in_train=False)

    @require_flash_attn_3
    @require_torch_gpu
    @mark.flash_attn_3_test
    @slow
    def test_flash_attn_3_from_config(self):
        self.flash_attn_from_config(attn_implementation="flash_attention_3", test_fwd_in_train=False)


@slow
@require_torch_accelerator
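The two tests added above differ only in the backend string they pass to the shared mixin helper. As a rough, self-contained sketch of that pattern (toy class names, not the actual transformers test code), a mixin method can accept a test_fwd_in_train flag that model-specific test classes override when the train-mode forward needs different inputs:

# Illustrative sketch only -- FlashAttnTestMixin and ToyGemma3LikeTest are made-up names,
# not the real ModelTesterMixin or Gemma3 test classes.
import unittest


class FlashAttnTestMixin:
    def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bool = True):
        # A real helper would instantiate the model from its config with the requested
        # attention backend; here we only record which mode the forward pass would use.
        assert attn_implementation in {"flash_attention_2", "flash_attention_3"}
        self.forward_mode = "train" if test_fwd_in_train else "eval"


class ToyGemma3LikeTest(FlashAttnTestMixin, unittest.TestCase):
    def test_flash_attn_2_from_config(self):
        # Gemma3-style override: opt out of the train-mode forward pass.
        self.flash_attn_from_config("flash_attention_2", test_fwd_in_train=False)
        self.assertEqual(self.forward_mode, "eval")


if __name__ == "__main__":
    unittest.main()

The real tests delegate to the flash_attn_from_config helper on ModelTesterMixin, which is the method modified further down in this commit.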

View File

@@ -2976,7 +2976,7 @@ class ModelTesterMixin:
    def flash_attn_inference_equivalence(
        self, attn_implementation: str, padding_side: str, atol: float = 4e-2, rtol: float = 4e-2
    ):
    ) -> None:
        r"""
        Tests the equivalence between the eager and flash attention implementations.
        This test is only for inference and runs with `dtype=torch.bfloat16`.
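The docstring above describes the check at a high level. A hedged sketch of the same idea outside the test suite could look as follows; the checkpoint id is a placeholder, and a CUDA device with flash-attn installed is assumed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "org/some-small-checkpoint"  # placeholder, not a real checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("flash attention equivalence check", return_tensors="pt").to("cuda")

# Same weights loaded twice: once with the eager attention path, once with FA2,
# both in bfloat16 as the test above does.
eager_model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="eager", dtype=torch.bfloat16
).to("cuda")
fa2_model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", dtype=torch.bfloat16
).to("cuda")

with torch.no_grad():
    logits_eager = eager_model(**inputs).logits
    logits_fa2 = fa2_model(**inputs).logits

# Loose tolerances, matching the atol/rtol defaults in the signature above.
torch.testing.assert_close(logits_eager, logits_fa2, atol=4e-2, rtol=4e-2)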
@@ -3114,9 +3114,6 @@ class ModelTesterMixin:
        torch.testing.assert_close(logits_1_eager, logits_1_fa, atol=atol, rtol=rtol)

        if padding_side == "left":
            torch.testing.assert_close(logits_2_eager[1:], logits_2_fa[1:], atol=atol, rtol=rtol)
            # Check it can run in training mode
            model.train()
            _ = model(**second_inputs)
        else:
            torch.testing.assert_close(logits_2_eager[:-1], logits_2_fa[:-1], atol=atol, rtol=rtol)
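The padding-aware comparison kept above can be illustrated with synthetic tensors (no GPU or model needed); only the non-padded positions are compared, dropping the first position for left padding and the last one for right padding:

import torch

atol = rtol = 4e-2  # same loose tolerances as the test defaults

# Stand-ins for the logits of one padded sequence from the eager and FA runs
# (sequence length 5, vocab size 11); the FA copy gets a tiny perturbation.
logits_eager = torch.randn(5, 11, dtype=torch.bfloat16)
logits_fa = logits_eager + 1e-3 * torch.randn_like(logits_eager)

for padding_side in ("left", "right"):
    if padding_side == "left":
        # The first position sits on padding, so it is excluded from the check.
        torch.testing.assert_close(logits_eager[1:], logits_fa[1:], atol=atol, rtol=rtol)
    else:
        # With right padding, the padded position is the last one instead.
        torch.testing.assert_close(logits_eager[:-1], logits_fa[:-1], atol=atol, rtol=rtol)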
@@ -3651,7 +3648,7 @@
            assert not loss.isnan().any()

    def flash_attn_from_config(self, attn_implementation: str):
    def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bool = True):
        r"""
        Tests if the model can be loaded with `attn_implementation` from the config and if the
        weights are not randomly initialized.
@@ -3669,6 +3666,14 @@
            config, attn_implementation=attn_implementation, dtype=torch.bfloat16
        ).to(torch_device)

        # By default, we perform the forward pass in train mode, because it's stricter than eval mode. If the
        # forward pass is successful in train mode, it will also be successful in eval mode. But since some models
        # (e.g. gemma3) need different inputs in train mode, we have the option to test the forward pass in eval mode.
        if test_fwd_in_train:
            fa_model = fa_model.train()
        else:
            fa_model = fa_model.eval()
        dummy_input = inputs_dict[fa_model.main_input_name]
        if dummy_input.dtype in [torch.float32, torch.float16]:
            dummy_input = dummy_input.to(torch.bfloat16)
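To see why the train-mode forward is the stricter default yet sometimes needs to be skipped, consider a toy module (purely illustrative, not how Gemma3 is implemented) whose forward requires an extra input only when self.training is set:

import torch
from torch import nn


class ToyTrainOnlyInputModel(nn.Module):
    """Toy module: forward needs an extra `mask` input only in train mode."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x, mask=None):
        if self.training:
            if mask is None:
                raise ValueError("`mask` is required in train mode")
            x = x * mask
        return self.proj(x)


model = ToyTrainOnlyInputModel()
x = torch.randn(2, 8)

model.eval()
_ = model(x)  # succeeds: eval-mode forward only needs `x`

model.train()
_ = model(x, mask=torch.ones(2, 8))  # succeeds: the train-only input is provided

try:
    _ = model(x)  # fails: train mode without the extra input
except ValueError as err:
    print(f"train-mode forward failed as expected: {err}")

In the real test, passing test_fwd_in_train=False keeps the model in eval mode, so the config-loading check still runs for models like Gemma3 without supplying the extra train-time inputs.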