Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 17:13:56 +08:00
[FA] Fix some model tests (#40350)

* fix
* cleanup, revert aimv2 fa changes
* fix aria
* I searched a long time, but the cross dependency is for the recent models so...
* this was something... evolla
* fix modernbert decoder + make fa test more robust
* nit
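All of the model hunks below follow the same pattern: each `PreTrainedModel` subclass advertises which attention backends it can run through class-level flags, and this commit flips those flags (with a short comment explaining why) for models whose mask creation or sub-encoders cannot handle flash attention. A minimal sketch of the pattern, using a hypothetical `MyModelPreTrainedModel` rather than code from this commit:

from transformers import PreTrainedModel

class MyModelPreTrainedModel(PreTrainedModel):
    # Hypothetical example class: these class-level flags are what the
    # attention-backend selection logic reads when deciding whether a
    # requested implementation (e.g. "flash_attention_2") is allowed.
    _supports_sdpa = True
    _supports_flash_attn = False  # e.g. mask creation only accounts for sdpa/eager
    _supports_flex_attn = False   # e.g. a sub-encoder dependency cannot use it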
@@ -630,6 +630,7 @@ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
class Aimv2Model(Aimv2PreTrainedModel):
    config: Aimv2Config
    _no_split_modules = ["Aimv2TextEmbeddings", "Aimv2EncoderLayer", "Aimv2VisionEmbeddings"]
+    _supports_flash_attn = True

    def __init__(self, config: Aimv2Config):
        super().__init__(config)
@@ -614,6 +614,8 @@ class Aimv2TextModel(Aimv2PreTrainedModel):

@auto_docstring
class Aimv2Model(CLIPModel, nn.Module):
+    _supports_flash_attn = True
+
    def __init__(self, config: Aimv2Config):
        nn.Module().__init__(config)
@@ -632,7 +632,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
    _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn = False
+    _supports_flash_attn = True
    _supports_sdpa = True

    _supports_attention_backend = True
@@ -1283,7 +1283,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
    _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn = False
+    _supports_flash_attn = True
    _supports_sdpa = True

    _supports_attention_backend = True
@@ -673,6 +673,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
    config: CLIPTextConfig

    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
@@ -830,6 +831,7 @@ class CLIPVisionModel(CLIPPreTrainedModel):
class CLIPModel(CLIPPreTrainedModel):
    config: CLIPConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
+    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager

    def __init__(self, config: CLIPConfig):
        super().__init__(config)
@@ -715,7 +715,6 @@ class EvollaSaProtPooler(nn.Module):
class EvollaSaProtPreTrainedModel(PreTrainedModel):
    config: SaProtConfig
    _no_split_modules = ["EvollaSaProtLayer"]
-    _supports_flash_attn = True

    def _init_weights(self, module):
        """Initialize the weights"""
@@ -1511,9 +1510,9 @@ class EvollaPreTrainedModel(PreTrainedModel):
        "EvollaSequenceAlignerCrossAttention",
    ]
    _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn = True
+    _supports_flash_attn = False  # see dependency on `EvollaSaProtProteinEncoder`
    _supports_sdpa = True
-    _supports_flex_attn = True
+    _supports_flex_attn = False  # see dependency on `EvollaSaProtProteinEncoder`

    _can_compile_fullgraph = True
    _supports_attention_backend = False
@@ -193,7 +193,6 @@ class EvollaSaProtPooler(EsmPooler):
class EvollaSaProtPreTrainedModel(PreTrainedModel):
    config: SaProtConfig
    _no_split_modules = ["EvollaSaProtLayer"]
-    _supports_flash_attn = True

    def _init_weights(self, module):
        """Initialize the weights"""
@@ -772,6 +771,8 @@ class EvollaDecoderLayer(LlamaDecoderLayer):


class EvollaPreTrainedModel(LlamaPreTrainedModel):
+    _supports_flash_attn = False  # see dependency on `EvollaSaProtProteinEncoder`
+    _supports_flex_attn = False  # see dependency on `EvollaSaProtProteinEncoder`
    _supports_attention_backend = False
    _no_split_modules = [
        "EvollaDecoderLayer",
@@ -283,7 +283,7 @@ class GraniteSpeechCTCEncoder(nn.Module):
class GraniteSpeechPreTrainedModel(PreTrainedModel):
    config: GraniteSpeechConfig

-    _supports_flash_attn = True
+    _supports_flash_attn = False  # `blip_2_qformer` dependency does not allow for this
    _supports_sdpa = True

    def _init_weights(self, module: nn.Module):
@@ -883,7 +883,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
    _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
    _supports_sdpa = True

-    _supports_flash_attn = True
+    _supports_flash_attn = False  # only eager/sdpa creation is supported
    _can_compile_fullgraph = False  # IDEFICS cannot compile due to dynamic control flow when checking inputs
    _supports_attention_backend = True
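The Evolla, GraniteSpeech, and IDEFICS hunks above disable flash attention on composite models because a sub-module (the SaProt protein encoder, the BLIP-2 Q-Former, the IDEFICS mask creation) cannot honor it, even though the language-model half could. A toy, hypothetical illustration of why the wrapper has to advertise the most restrictive capability of its parts (names are made up; this is not how transformers computes the flag):

class ProteinEncoder:  # stand-in for a sub-encoder without flash-attention support
    _supports_flash_attn = False

class LanguageDecoder:  # stand-in for a Llama-style decoder that does support it
    _supports_flash_attn = True

class CompositeModel:
    # The wrapper can only claim what every sub-model supports, so the flag is
    # effectively the logical AND of its parts.
    _supports_flash_attn = all(
        m._supports_flash_attn for m in (ProteinEncoder, LanguageDecoder)
    )

print(CompositeModel._supports_flash_attn)  # False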
@@ -545,6 +545,7 @@ class MetaClip2TextModel(MetaClip2PreTrainedModel):
    config: MetaClip2TextConfig

    _no_split_modules = ["MetaClip2TextEmbeddings", "MetaClip2EncoderLayer"]
+    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager

    def __init__(self, config: MetaClip2TextConfig):
        super().__init__(config)
@@ -789,6 +790,7 @@ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
class MetaClip2Model(MetaClip2PreTrainedModel):
    config: MetaClip2Config
    _no_split_modules = ["MetaClip2TextEmbeddings", "MetaClip2EncoderLayer", "MetaClip2VisionEmbeddings"]
+    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager

    def __init__(self, config: MetaClip2Config):
        super().__init__(config)
@@ -221,12 +221,8 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
@auto_docstring
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
    config: ModernBertDecoderConfig
    base_model_prefix = "model"
    _skip_keys_device_placement = ["past_key_values"]
    _no_split_modules = ["ModernBertDecoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = False
    _supports_gradient_checkpointing = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    _can_record_outputs = {
@@ -280,6 +276,20 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
            if module.bias is not None:
                module.bias.data.zero_()
+
+    def _check_and_adjust_attn_implementation(
+        self, attn_implementation: Optional[str], is_init_check: bool = False
+    ) -> str:
+        """We overwrite this to make sdpa the first selection again if nothing was requested."""
+
+        try:
+            attn_implementation = (
+                "sdpa" if attn_implementation is None and self._sdpa_can_dispatch() else attn_implementation
+            )
+        except (ValueError, ImportError):
+            pass
+
+        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)


@auto_docstring
class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
@@ -398,12 +398,8 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
@auto_docstring
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
    config: ModernBertDecoderConfig
    base_model_prefix = "model"
    _skip_keys_device_placement = ["past_key_values"]
    _no_split_modules = ["ModernBertDecoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = False
    _supports_gradient_checkpointing = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    _can_record_outputs = {
@@ -457,6 +453,20 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
            if module.bias is not None:
                module.bias.data.zero_()
+
+    def _check_and_adjust_attn_implementation(
+        self, attn_implementation: Optional[str], is_init_check: bool = False
+    ) -> str:
+        """We overwrite this to make sdpa the first selection again if nothing was requested."""
+
+        try:
+            attn_implementation = (
+                "sdpa" if attn_implementation is None and self._sdpa_can_dispatch() else attn_implementation
+            )
+        except (ValueError, ImportError):
+            pass
+
+        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)


@auto_docstring
class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
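The override added in the two hunks above restores SDPA as the first selection for ModernBertDecoder when the caller requests nothing, while still letting flash attention be chosen explicitly. A hedged sketch of the resulting behavior, assuming the ModernBERT decoder classes are available in the installed transformers version, the default `ModernBertDecoderConfig`, and a PyTorch build where SDPA can dispatch:

from transformers import ModernBertDecoderConfig, ModernBertDecoderForCausalLM

# No attn_implementation requested: the override tries "sdpa" first and only
# falls back to the parent's selection logic if SDPA cannot dispatch.
model = ModernBertDecoderForCausalLM(ModernBertDecoderConfig())
print(model.config._attn_implementation)  # expected: "sdpa"

# An explicit request is passed through to the parent logic unchanged
# (this call assumes the flash-attn package and a supported GPU).
model.set_attn_implementation("flash_attention_2")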
@@ -622,7 +622,6 @@ class SiglipTextTransformer(nn.Module):
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.head = nn.Linear(embed_dim, config.projection_size)
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @can_return_tuple
    @auto_docstring
@@ -649,7 +648,10 @@ class SiglipTextTransformer(nn.Module):

        # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
        # expand attention_mask
-        if attention_mask is not None and not self._use_flash_attention_2:
+        uses_flash_attention = "flash" in self.config._attn_implementation
+        if uses_flash_attention:
+            attention_mask = None
+        elif attention_mask is not None and not uses_flash_attention:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
@@ -647,7 +647,6 @@ class Siglip2TextTransformer(nn.Module):
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.head = nn.Linear(embed_dim, config.projection_size)
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @can_return_tuple
    @auto_docstring
@@ -674,7 +673,10 @@ class Siglip2TextTransformer(nn.Module):

        # note: Siglip2's text model does not use a causal mask, unlike the original CLIP model.
        # expand attention_mask
-        if attention_mask is not None and not self._use_flash_attention_2:
+        uses_flash_attention = "flash" in self.config._attn_implementation
+        if uses_flash_attention:
+            attention_mask = None
+        elif attention_mask is not None and not uses_flash_attention:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
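In both SigLIP text towers, the `_use_flash_attention_2` flag cached at `__init__` is replaced by a runtime check of `config._attn_implementation`, so backend switches made after construction (for example via `set_attn_implementation`) are respected: with a flash backend the 2D padding mask is dropped, otherwise it is expanded to 4D as before. A small hedged sketch of that expansion with made-up sizes:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# Hypothetical batch of 2 sequences of length 5; the second one has 2 padding tokens.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])
expanded = _prepare_4d_attention_mask(attention_mask, torch.float32)
print(expanded.shape)  # torch.Size([2, 1, 5, 5]) -> [batch_size, 1, tgt_seq_len, src_seq_len]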
@@ -3501,6 +3501,11 @@ class ModelTesterMixin:
        model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
        model.to(torch_device)

+        # Some models have support for FA but not SDPA - making sure we have a valid attention
+        initial_attention_implementation = "sdpa"
+        if model.config._attn_implementation != "sdpa":
+            initial_attention_implementation = "eager"
+
        dummy_input = inputs_dict[model.main_input_name][:1]
        if dummy_input.dtype in [torch.float32, torch.float16]:
            dummy_input = dummy_input.to(torch.bfloat16)
@@ -3526,7 +3531,7 @@ class ModelTesterMixin:
        model.set_attn_implementation(attn_implementation)
        outputs_fa = model(dummy_input, output_hidden_states=True)

-        model.set_attn_implementation("sdpa")
+        model.set_attn_implementation(initial_attention_implementation)
        logits = (
            outputs.hidden_states[-1]
            if not model.config.is_encoder_decoder
@@ -3563,7 +3568,7 @@ class ModelTesterMixin:
        model.set_attn_implementation(attn_implementation)
        outputs_fa = model(dummy_input, **other_inputs)

-        model.set_attn_implementation("sdpa")
+        model.set_attn_implementation(initial_attention_implementation)
        logits = (
            outputs.hidden_states[-1]
            if not model.config.is_encoder_decoder
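Taken together, the test changes record a backend that is actually valid for the model under test before switching to flash attention, and restore that backend afterwards instead of unconditionally switching back to "sdpa", which failed for models that only support eager. A condensed, hypothetical sketch of the flow (the tiny checkpoint, tolerances, and the availability of a CUDA device with the flash-attn package are assumptions, not part of this commit):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16
).to("cuda")

# Remember a backend that is valid for this model before switching to flash attention.
initial_attention_implementation = "sdpa" if model.config._attn_implementation == "sdpa" else "eager"

dummy_input = torch.randint(0, model.config.vocab_size, (1, 8), device="cuda")
outputs = model(dummy_input, output_hidden_states=True)

model.set_attn_implementation("flash_attention_2")
outputs_fa = model(dummy_input, output_hidden_states=True)

# Restore the previously valid backend rather than hard-coding "sdpa".
model.set_attn_implementation(initial_attention_implementation)

torch.testing.assert_close(
    outputs.hidden_states[-1], outputs_fa.hidden_states[-1], atol=4e-2, rtol=4e-2
)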