Mirror of https://github.com/huggingface/transformers.git
Gemma3 fixes (#41572)
* Multiple device error fix
* FA2 equivalence fix
* Move the train fwd in cfg test
* Style
* Added comment
* Made the comment more clear
@@ -19,6 +19,7 @@ import unittest
 
 import pytest
 from parameterized import parameterized
+from pytest import mark
 
 from transformers import (
     AutoModelForCausalLM,
@@ -33,9 +34,11 @@ from transformers.testing_utils import (
     is_flash_attn_2_available,
     require_deterministic_for_xpu,
+    require_flash_attn,
+    require_flash_attn_3,
     require_read_token,
     require_torch,
     require_torch_accelerator,
     require_torch_gpu,
     require_torch_large_accelerator,
     slow,
     torch_device,
@@ -342,6 +345,20 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
             for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
             self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_from_config(self):
+        self.flash_attn_from_config(attn_implementation="flash_attention_2", test_fwd_in_train=False)
+
+    @require_flash_attn_3
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    @slow
+    def test_flash_attn_3_from_config(self):
+        self.flash_attn_from_config(attn_implementation="flash_attention_3", test_fwd_in_train=False)
+
     @slow
     @require_torch_accelerator
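For orientation, the new tests above follow the usual ModelTesterMixin pattern: the model-specific test class calls a shared helper and, for Gemma3, opts out of the train-mode forward via the new keyword argument. Below is a minimal, self-contained sketch of that pattern with toy classes; it is not the actual transformers code, and the names ToyTesterMixin and ToyGemma3StyleTest are invented for illustration.

# Minimal sketch of the pattern used by the new Gemma3 tests above: a shared
# helper with a `test_fwd_in_train` switch, and a model-specific test that
# disables the train-mode forward. Toy classes, not the real ModelTesterMixin.
import unittest

from pytest import mark


class ToyTesterMixin:
    def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bool = True):
        # The real helper loads the model from its config with the requested
        # attention implementation; here we only record what was requested.
        self.requested = (attn_implementation, test_fwd_in_train)


class ToyGemma3StyleTest(ToyTesterMixin, unittest.TestCase):
    @mark.flash_attn_test
    def test_flash_attn_2_from_config(self):
        # Gemma3-style override: skip the train-mode forward because the model
        # needs different inputs in train mode.
        self.flash_attn_from_config(attn_implementation="flash_attention_2", test_fwd_in_train=False)
        self.assertEqual(self.requested, ("flash_attention_2", False))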
@@ -2976,7 +2976,7 @@ class ModelTesterMixin:
 
     def flash_attn_inference_equivalence(
         self, attn_implementation: str, padding_side: str, atol: float = 4e-2, rtol: float = 4e-2
-    ):
+    ) -> None:
         r"""
         Tests the equivalence between the eager and flash attention implementations.
         This test is only for inference and runs with `dtype=torch.bfloat16`.
@@ -3114,9 +3114,6 @@ class ModelTesterMixin:
         torch.testing.assert_close(logits_1_eager, logits_1_fa, atol=atol, rtol=rtol)
         if padding_side == "left":
             torch.testing.assert_close(logits_2_eager[1:], logits_2_fa[1:], atol=atol, rtol=rtol)
-            # Check it can run in training mode
-            model.train()
-            _ = model(**second_inputs)
         else:
             torch.testing.assert_close(logits_2_eager[:-1], logits_2_fa[:-1], atol=atol, rtol=rtol)
 
@@ -3651,7 +3648,7 @@ class ModelTesterMixin:
 
         assert not loss.isnan().any()
 
-    def flash_attn_from_config(self, attn_implementation: str):
+    def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bool = True):
         r"""
         Tests if the model can be loaded with `attn_implementation` from the config and if the
         weights are not randomly initialized.
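As an aside, the equivalence check above relies on torch.testing.assert_close, which compares tensors elementwise against the bound atol + rtol * |expected|. A standalone sketch with toy tensors (not real model logits), using the mixin's default 4e-2 tolerances:

# Standalone sketch of the tolerance check used in flash_attn_inference_equivalence:
# torch.testing.assert_close raises if |actual - expected| exceeds atol + rtol * |expected|.
# Toy tensors below, not real eager/flash-attention logits.
import torch

eager_logits = torch.tensor([1.00, 2.00, 3.00], dtype=torch.bfloat16)
fa_logits = torch.tensor([1.03, 1.98, 3.01], dtype=torch.bfloat16)

# Passes: every elementwise difference is within 4e-2 + 4e-2 * |expected|.
torch.testing.assert_close(fa_logits, eager_logits, atol=4e-2, rtol=4e-2)

# A larger discrepancy would raise an AssertionError, e.g.:
# torch.testing.assert_close(fa_logits + 0.5, eager_logits, atol=4e-2, rtol=4e-2)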
@@ -3669,6 +3666,14 @@ class ModelTesterMixin:
             config, attn_implementation=attn_implementation, dtype=torch.bfloat16
         ).to(torch_device)
 
+        # By default, we perform the forward pass in train mode, because it's more strict than eval mode. If the
+        # forward pass is successful in train mode, it will also be successful in eval mode. But since some models
+        # (e.g. gemma3) need different inputs in train mode, we have the option to test the forward pass in eval mode.
+        if test_fwd_in_train:
+            fa_model = fa_model.train()
+        else:
+            fa_model = fa_model.eval()
+
         dummy_input = inputs_dict[fa_model.main_input_name]
         if dummy_input.dtype in [torch.float32, torch.float16]:
             dummy_input = dummy_input.to(torch.bfloat16)
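The comment added above hinges on the difference between train and eval mode. A toy PyTorch illustration (not a transformers model) of why a train-mode forward is the stricter check:

# Toy illustration of why a train-mode forward exercises more than eval mode:
# modules such as Dropout are active only under model.train(), so a forward
# pass that succeeds in train mode also succeeds in eval mode (the converse may not hold).
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(2, 8)

model.train()          # dropout active: the stricter forward pass
train_out = model(x)

model.eval()           # dropout disabled: the mode Gemma3-style tests fall back to
with torch.no_grad():
    eval_out = model(x)

print(train_out.shape, eval_out.shape)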