Fix fp32_ln for various models (#41605)

* Add is_causal to KosmosTextAttention

* Move get_target_dtype so it can be imported elsewhere

* Fix fp32 flash attention bug in bark

* Fix is_causal in mllama

* Fix fp32 issue on StableLM

* Fix repo-consistency
Rémi Ouazan
2025-10-16 14:18:49 +02:00
committed by GitHub
parent b9bd8c45a1
commit 2935a1be19
6 changed files with 26 additions and 10 deletions
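
For context (not part of the diff): "fp32_ln" refers to keeping the LayerNorms in float32 while the rest of the model runs in half precision. Those norms silently upcast the activations, and the flash-attention kernels only accept fp16/bf16 inputs, so the attention path has to cast an fp32 query/key/value down to a compatible dtype before calling the kernel. A minimal sketch of that constraint in plain PyTorch, with a stand-in for the real kernel (all names below are hypothetical):

```python
import torch

def kernel_stand_in(q, k, v):
    # Stand-in for the flash-attention kernel, which rejects float32 inputs.
    if q.dtype not in (torch.float16, torch.bfloat16):
        raise RuntimeError("flash attention only supports fp16/bf16 inputs")
    return v

ln_fp32 = torch.nn.LayerNorm(8)          # norm kept in float32 while the model runs in half
states = ln_fp32(torch.randn(1, 4, 8))   # its output is float32, so q/k/v downstream are too
q = k = v = states.view(1, 4, 1, 8)

try:
    kernel_stand_in(q, k, v)
except RuntimeError as err:
    print(err)                           # the failure mode this PR guards against

target_dtype = torch.float16             # what get_target_dtype resolves in the real code
print(kernel_stand_in(q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)).dtype)
```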

View File

@@ -11,6 +11,19 @@ logger = logging.get_logger(__name__)
_use_top_left_mask = flash_attn_supports_top_left_mask()
def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype:
"""If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise."""
if query.dtype == torch.float32:
if torch.is_autocast_enabled():
return torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(module.config, "_pre_quantization_dtype"):
return module.config._pre_quantization_dtype
else:
return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
return None
def flash_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
@@ -48,15 +61,7 @@ def flash_attention_forward(
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (usually our RMSNorm modules handle it correctly)
target_dtype = None
if query.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(module.config, "_pre_quantization_dtype"):
target_dtype = module.config._pre_quantization_dtype
else:
target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
target_dtype = get_target_dtype(query, module)
# Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
is_causal = kwargs.pop("is_causal", None)
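
The branch order in the new helper matters for quantized models: their Linear weights no longer carry a meaningful floating-point dtype, so the config's _pre_quantization_dtype takes precedence over inspecting the weights. A small runnable restatement of that logic (Cfg and QuantizedLikeModule are hypothetical; this mirrors the diffed helper rather than importing it, which would require flash-attn):

```python
import torch

class QuantizedLikeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Quantized checkpoints record the original dtype on the config.
        self.config = type("Cfg", (), {"_pre_quantization_dtype": torch.bfloat16})()
        self.q_proj = torch.nn.Linear(8, 8)   # pretend-quantized; its weight dtype is not meaningful

module = QuantizedLikeModule()
query = torch.randn(1, 4, 8)                  # float32 query

target_dtype = None
if query.dtype == torch.float32 and not torch.is_autocast_enabled():
    if hasattr(module.config, "_pre_quantization_dtype"):
        target_dtype = module.config._pre_quantization_dtype
    else:
        target_dtype = next(m for m in module.modules() if isinstance(m, torch.nn.Linear)).weight.dtype
print(target_dtype)                           # torch.bfloat16, taken from the config, not the weights
```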

View File

@@ -57,6 +57,7 @@ from .generation_configuration_bark import (
if is_flash_attn_available():
from ...integrations.flash_attention import get_target_dtype
from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -78,6 +79,7 @@ class BarkSelfAttention(nn.Module):
self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.embed_dim // self.num_heads
self.config = config
if config.hidden_size % config.num_heads != 0:
raise ValueError(
@@ -228,6 +230,8 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
if past_key_values is not None:
key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
target_dtype = get_target_dtype(query, self) # if the query is in float32, this is the dtype to cast to for FA
attn_output = _flash_attention_forward(
query,
key,
@@ -237,6 +241,7 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
dropout=self.dropout if self.training else 0.0,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
target_dtype=target_dtype,
)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
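
Bark's flash path now resolves the cast target up front and hands it to _flash_attention_forward; it also stores self.config, which the new helper inspects for quantization info. The contract, as read from the diff, in a self-contained sketch (forward_with_optional_cast is a hypothetical stand-in for the real helper):

```python
import torch

def forward_with_optional_cast(q, k, v, target_dtype=None):
    # Only cast when a target is given and the inputs are float32;
    # half-precision inputs pass through untouched.
    if target_dtype is not None and q.dtype == torch.float32:
        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
    return q  # the real helper would run the flash-attention kernel here

fp32_q = torch.randn(1, 2, 4, 8)
half_q = fp32_q.half()
print(forward_with_optional_cast(fp32_q, fp32_q, fp32_q, target_dtype=torch.bfloat16).dtype)  # torch.bfloat16
print(forward_with_optional_cast(half_q, half_q, half_q).dtype)                               # torch.float16, untouched
```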

View File

@@ -280,12 +280,12 @@ class BltSelfAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.layer_idx = layer_idx
self.is_causal = True
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.is_causal = True
def forward(
self,

View File

@@ -680,6 +680,7 @@ class KosmosTextAttention(nn.Module):
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.is_causal = True
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(

View File

@@ -519,6 +519,7 @@ class MllamaTextSelfAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.layer_idx = layer_idx
self.is_causal = True
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
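
Kosmos-2 and Mllama gain an is_causal attribute on their attention modules (BLT's appears to be merely reordered for consistency) because the shared attention code falls back to module.is_causal when the caller does not pass it explicitly, as the kwargs.pop("is_causal", None) line in the first file suggests. A hedged, self-contained sketch of that fallback (names hypothetical):

```python
import torch

class SketchAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.is_causal = True   # the attribute these attention classes now expose

def resolve_is_causal(module, **kwargs):
    # Prefer an explicitly passed value, otherwise fall back to the module attribute.
    passed = kwargs.pop("is_causal", None)
    return passed if passed is not None else module.is_causal

attn = SketchAttention()
print(resolve_is_causal(attn))                    # True, from the attribute
print(resolve_is_causal(attn, is_causal=False))   # False, the explicit kwarg wins
```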

View File

@@ -52,6 +52,7 @@ if is_torch_flex_attn_available():
if is_flash_attn_available():
from ...integrations.flash_attention import get_target_dtype
from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -495,6 +496,8 @@ class StableLmFlashAttention2(StableLmAttention):
dropout_rate = self.attention_dropout.p if self.training else 0.0
target_dtype = get_target_dtype(query_states, self)
attn_output = _flash_attention_forward(
query_states,
key_states,
@@ -505,6 +508,7 @@ class StableLmFlashAttention2(StableLmAttention):
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
target_dtype=target_dtype,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
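
StableLM gets the same treatment as Bark: resolve target_dtype once and forward it to _flash_attention_forward. Putting the pieces together, a hypothetical end-to-end mirror of the fixed flow (no flash-attn required; SDPA stands in for the kernel, and only the fallback branch of the dtype resolution is exercised):

```python
import torch
import torch.nn.functional as F

class SketchFlashAttention(torch.nn.Module):
    """Hypothetical stand-in for a StableLM-like flash-attention block."""

    def __init__(self, hidden=8):
        super().__init__()
        self.config = type("Cfg", (), {})()                        # no _pre_quantization_dtype here
        self.q_proj = torch.nn.Linear(hidden, hidden).bfloat16()   # fallback dtype source
        self.is_causal = True

    def forward(self, query, key, value):
        # get_target_dtype's fallback branch: fp32 query -> dtype of the first Linear.
        target_dtype = None
        if query.dtype == torch.float32:
            target_dtype = next(
                m for m in self.modules() if isinstance(m, torch.nn.Linear)
            ).weight.dtype
        if target_dtype is not None:
            query, key, value = (t.to(target_dtype) for t in (query, key, value))
        # SDPA stands in for _flash_attention_forward here.
        return F.scaled_dot_product_attention(query, key, value, is_causal=self.is_causal)

x = torch.randn(1, 2, 4, 8)                            # fp32 q/k/v, e.g. downstream of fp32 LayerNorms
print(SketchFlashAttention().forward(x, x, x).dtype)   # torch.bfloat16
```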