[FA] Fix some model tests (#40350)

* fix

* cleanup, revert aimv2 fa changes

* fix aria

* I searched a long time, but the cross dependency is for the recent models, so...

* this was something... evolla

* fix modernbert decoder + make fa test more robust

* nit
Author: Anton Vlasjuk
Date: 2025-08-21 18:08:21 +02:00
Committed by: GitHub
Parent: f46f29dd7c
Commit: cb1df4d26a
15 changed files with 58 additions and 22 deletions

View File

@@ -630,6 +630,7 @@ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
class Aimv2Model(Aimv2PreTrainedModel):
config: Aimv2Config
_no_split_modules = ["Aimv2TextEmbeddings", "Aimv2EncoderLayer", "Aimv2VisionEmbeddings"]
+ _supports_flash_attn = True
def __init__(self, config: Aimv2Config):
super().__init__(config)

View File

@@ -614,6 +614,8 @@ class Aimv2TextModel(Aimv2PreTrainedModel):
@auto_docstring
class Aimv2Model(CLIPModel, nn.Module):
+ _supports_flash_attn = True
def __init__(self, config: Aimv2Config):
nn.Module().__init__(config)

View File

@@ -632,7 +632,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
- _supports_flash_attn = False
+ _supports_flash_attn = True
_supports_sdpa = True
_supports_attention_backend = True

View File

@@ -1283,7 +1283,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
- _supports_flash_attn = False
+ _supports_flash_attn = True
_supports_sdpa = True
_supports_attention_backend = True

View File

@@ -673,6 +673,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
config: CLIPTextConfig
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+ _supports_flash_attn = False # mask creation only accounts for sdpa/eager
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
@@ -830,6 +831,7 @@ class CLIPVisionModel(CLIPPreTrainedModel):
class CLIPModel(CLIPPreTrainedModel):
config: CLIPConfig
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
+ _supports_flash_attn = False # mask creation only accounts for sdpa/eager
def __init__(self, config: CLIPConfig):
super().__init__(config)
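Note on the CLIP change above: with `_supports_flash_attn = False`, a flash-attention request for these classes should now be rejected up front by the attention-implementation checks instead of failing later in mask creation. A minimal sketch of what that looks like from the user side (the checkpoint is only an example and the exact exception text is not verified against this revision):

from transformers import CLIPTextModel

# with _supports_flash_attn = False, asking for FA2 is expected to raise during
# loading rather than produce a mask that flash attention cannot consume
try:
    CLIPTextModel.from_pretrained(
        "openai/clip-vit-base-patch32", attn_implementation="flash_attention_2"
    )
except (ValueError, ImportError) as err:
    print(err)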

View File

@@ -715,7 +715,6 @@ class EvollaSaProtPooler(nn.Module):
class EvollaSaProtPreTrainedModel(PreTrainedModel):
config: SaProtConfig
_no_split_modules = ["EvollaSaProtLayer"]
- _supports_flash_attn = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -1511,9 +1510,9 @@ class EvollaPreTrainedModel(PreTrainedModel):
"EvollaSequenceAlignerCrossAttention",
]
_skip_keys_device_placement = ["past_key_values"]
- _supports_flash_attn = True
+ _supports_flash_attn = False # see dependency on `EvollaSaProtProteinEncoder`
_supports_sdpa = True
- _supports_flex_attn = True
+ _supports_flex_attn = False # see dependency on `EvollaSaProtProteinEncoder`
_can_compile_fullgraph = True
_supports_attention_backend = False

View File

@@ -193,7 +193,6 @@ class EvollaSaProtPooler(EsmPooler):
class EvollaSaProtPreTrainedModel(PreTrainedModel):
config: SaProtConfig
_no_split_modules = ["EvollaSaProtLayer"]
- _supports_flash_attn = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -772,6 +771,8 @@ class EvollaDecoderLayer(LlamaDecoderLayer):
class EvollaPreTrainedModel(LlamaPreTrainedModel):
+ _supports_flash_attn = False # see dependency on `EvollaSaProtProteinEncoder`
+ _supports_flex_attn = False # see dependency on `EvollaSaProtProteinEncoder`
_supports_attention_backend = False
_no_split_modules = [
"EvollaDecoderLayer",

View File

@@ -283,7 +283,7 @@ class GraniteSpeechCTCEncoder(nn.Module):
class GraniteSpeechPreTrainedModel(PreTrainedModel):
config: GraniteSpeechConfig
- _supports_flash_attn = True
+ _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this
_supports_sdpa = True
def _init_weights(self, module: nn.Module):

View File

@@ -883,7 +883,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True
- _supports_flash_attn = True
+ _supports_flash_attn = False # only eager/sdpa creation is supported
_can_compile_fullgraph = False # IDEFICS cannot compile due to dynamic control flow when checking inputs
_supports_attention_backend = True

View File

@@ -545,6 +545,7 @@ class MetaClip2TextModel(MetaClip2PreTrainedModel):
config: MetaClip2TextConfig
_no_split_modules = ["MetaClip2TextEmbeddings", "MetaClip2EncoderLayer"]
+ _supports_flash_attn = False # mask creation only accounts for sdpa/eager
def __init__(self, config: MetaClip2TextConfig):
super().__init__(config)
@@ -789,6 +790,7 @@ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
class MetaClip2Model(MetaClip2PreTrainedModel):
config: MetaClip2Config
_no_split_modules = ["MetaClip2TextEmbeddings", "MetaClip2EncoderLayer", "MetaClip2VisionEmbeddings"]
+ _supports_flash_attn = False # mask creation only accounts for sdpa/eager
def __init__(self, config: MetaClip2Config):
super().__init__(config)

View File

@@ -221,12 +221,8 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
@auto_docstring
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
config: ModernBertDecoderConfig
base_model_prefix = "model"
_skip_keys_device_placement = ["past_key_values"]
_no_split_modules = ["ModernBertDecoderLayer"]
_supports_flash_attn = True
_supports_sdpa = False
_supports_gradient_checkpointing = True
_can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
@@ -280,6 +276,20 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
+ def _check_and_adjust_attn_implementation(
+     self, attn_implementation: Optional[str], is_init_check: bool = False
+ ) -> str:
+     """We overwrite this to make sdpa the first selection again if nothing was requested."""
+     try:
+         attn_implementation = (
+             "sdpa" if attn_implementation is None and self._sdpa_can_dispatch() else attn_implementation
+         )
+     except (ValueError, ImportError):
+         pass
+     return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)
@auto_docstring
class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
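Note on the override above: per its docstring, sdpa is tried first again when no attention implementation is requested. A minimal sketch of the expected behaviour (the tiny config values are illustrative, and the class names are taken from the diff):

from transformers import ModernBertDecoderConfig, ModernBertDecoderModel

# no attn_implementation requested; the override should pick sdpa first if it can dispatch
config = ModernBertDecoderConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2)
model = ModernBertDecoderModel(config)
print(model.config._attn_implementation)  # expected "sdpa" on a setup where sdpa is available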

View File

@@ -398,12 +398,8 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
@auto_docstring
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
config: ModernBertDecoderConfig
base_model_prefix = "model"
_skip_keys_device_placement = ["past_key_values"]
_no_split_modules = ["ModernBertDecoderLayer"]
_supports_flash_attn = True
_supports_sdpa = False
_supports_gradient_checkpointing = True
_can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
@@ -457,6 +453,20 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
+ def _check_and_adjust_attn_implementation(
+     self, attn_implementation: Optional[str], is_init_check: bool = False
+ ) -> str:
+     """We overwrite this to make sdpa the first selection again if nothing was requested."""
+     try:
+         attn_implementation = (
+             "sdpa" if attn_implementation is None and self._sdpa_can_dispatch() else attn_implementation
+         )
+     except (ValueError, ImportError):
+         pass
+     return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)
@auto_docstring
class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):

View File

@@ -622,7 +622,6 @@ class SiglipTextTransformer(nn.Module):
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = nn.Linear(embed_dim, config.projection_size)
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
@can_return_tuple
@auto_docstring
@@ -649,7 +648,10 @@ class SiglipTextTransformer(nn.Module):
# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
- if attention_mask is not None and not self._use_flash_attention_2:
+ uses_flash_attention = "flash" in self.config._attn_implementation
+ if uses_flash_attention:
+     attention_mask = None
+ elif attention_mask is not None and not uses_flash_attention:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
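For the non-flash branch above, `_prepare_4d_attention_mask` turns a 2D padding mask into the 4D additive mask the attention layers expect. A small, self-contained sketch of its effect (the import path is the internal one used by the modeling code and may move between versions):

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

attention_mask = torch.tensor([[1, 1, 1, 0]])  # [batch_size, seq_len], 0 marks padding
expanded = _prepare_4d_attention_mask(attention_mask, dtype=torch.float32)
print(expanded.shape)  # [batch_size, 1, tgt_seq_len, src_seq_len] -> torch.Size([1, 1, 4, 4])
# padded positions are filled with the dtype minimum so softmax effectively ignores them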

View File

@@ -647,7 +647,6 @@ class Siglip2TextTransformer(nn.Module):
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = nn.Linear(embed_dim, config.projection_size)
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
@can_return_tuple
@auto_docstring
@@ -674,7 +673,10 @@ class Siglip2TextTransformer(nn.Module):
# note: Siglip2's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
- if attention_mask is not None and not self._use_flash_attention_2:
+ uses_flash_attention = "flash" in self.config._attn_implementation
+ if uses_flash_attention:
+     attention_mask = None
+ elif attention_mask is not None and not uses_flash_attention:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

View File

@@ -3501,6 +3501,11 @@ class ModelTesterMixin:
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
+ # Some models have support for FA but not SDPA - making sure we have a valid attention
+ initial_attention_implementation = "sdpa"
+ if model.config._attn_implementation != "sdpa":
+     initial_attention_implementation = "eager"
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
@@ -3526,7 +3531,7 @@
model.set_attn_implementation(attn_implementation)
outputs_fa = model(dummy_input, output_hidden_states=True)
model.set_attn_implementation("sdpa")
model.set_attn_implementation(initial_attention_implementation)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
@@ -3563,7 +3568,7 @@
model.set_attn_implementation(attn_implementation)
outputs_fa = model(dummy_input, **other_inputs)
model.set_attn_implementation("sdpa")
model.set_attn_implementation(initial_attention_implementation)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder