Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 17:13:56 +08:00
Fix fp32_ln for various models (#41605)
* Add is_causal to KosmosTextAttention
* Move get_target_dtype to be imported elsewhere
* Fix fp32 flash attention bug in bark
* Fix is_causal in mllama
* Fix fp32 issue on StableLM
* Fix repo-consistency
@@ -11,6 +11,19 @@ logger = logging.get_logger(__name__)
 _use_top_left_mask = flash_attn_supports_top_left_mask()


+def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype:
+    """If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise."""
+    if query.dtype == torch.float32:
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_gpu_dtype()
+        # Handle the case where the model is quantized
+        elif hasattr(module.config, "_pre_quantization_dtype"):
+            return module.config._pre_quantization_dtype
+        else:
+            return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
+    return None
+
+
 def flash_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,

@@ -48,15 +61,7 @@ def flash_attention_forward(
     # cast them back in the correct dtype just to be sure everything works as expected.
     # This might slowdown training & inference so it is recommended to not cast the LayerNorms
     # in fp32. (usually our RMSNorm modules handle it correctly)
-    target_dtype = None
-    if query.dtype == torch.float32:
-        if torch.is_autocast_enabled():
-            target_dtype = torch.get_autocast_gpu_dtype()
-        # Handle the case where the model is quantized
-        elif hasattr(module.config, "_pre_quantization_dtype"):
-            target_dtype = module.config._pre_quantization_dtype
-        else:
-            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
+    target_dtype = get_target_dtype(query, module)

     # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
     is_causal = kwargs.pop("is_causal", None)
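For reference, the new helper's branching can be exercised on its own. The standalone Python sketch below (not part of the diff) restates the same logic so it runs without importing transformers; TinyAttention and resolve_target_dtype are hypothetical names used only for illustration, and the printed dtypes assume no autocast context is active.

import torch
from torch import nn


class TinyAttention(nn.Module):
    """Stand-in module: carries a config and contains at least one nn.Linear."""

    def __init__(self, dtype=torch.float16):
        super().__init__()
        self.config = type("Config", (), {})()  # hypothetical bare config without _pre_quantization_dtype
        self.proj = nn.Linear(8, 8, dtype=dtype)


def resolve_target_dtype(query, module):
    # Same branching as get_target_dtype above: autocast dtype first, then the
    # pre-quantization dtype recorded on the config, otherwise the dtype of the
    # first nn.Linear weight found in the module.
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
        elif hasattr(module.config, "_pre_quantization_dtype"):
            return module.config._pre_quantization_dtype
        else:
            return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
    return None


module = TinyAttention(dtype=torch.float16)
fp32_query = torch.randn(1, 4, 8, dtype=torch.float32)
print(resolve_target_dtype(fp32_query, module))                     # torch.float16, taken from the Linear weight
print(resolve_target_dtype(fp32_query.to(torch.bfloat16), module))  # None, nothing needs casting

The Linear-weight fallback assumes the module contains at least one nn.Linear, which holds for the attention classes touched by this commit.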
@@ -57,6 +57,7 @@ from .generation_configuration_bark import (


 if is_flash_attn_available():
+    from ...integrations.flash_attention import get_target_dtype
     from ...modeling_flash_attention_utils import _flash_attention_forward


@@ -78,6 +79,7 @@ class BarkSelfAttention(nn.Module):
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.config = config

         if config.hidden_size % config.num_heads != 0:
             raise ValueError(

@@ -228,6 +230,8 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
         if past_key_values is not None:
             key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})

+        target_dtype = get_target_dtype(query, self)  # if the query is in float32, this is the dtype to cast to for FA
+
         attn_output = _flash_attention_forward(
             query,
             key,

@@ -237,6 +241,7 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
             dropout=self.dropout if self.training else 0.0,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             is_causal=self.is_causal,
+            target_dtype=target_dtype,
         )

         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
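The Bark changes follow the pattern the helper enables: resolve the cast target once, then pass it as target_dtype to _flash_attention_forward, which downcasts float32 q/k/v before the flash kernel runs (the kernel only accepts fp16/bf16). The sketch below illustrates that cast-before-kernel idea with hypothetical names (fake_flash_kernel, attention_with_target_dtype) and SDPA standing in for the real kernel; it is a simplification, not the library's implementation.

import torch


def fake_flash_kernel(q, k, v):
    # Stand-in for the real flash-attention kernel, which rejects float32 inputs.
    assert q.dtype in (torch.float16, torch.bfloat16), "flash attention expects half precision"
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


def attention_with_target_dtype(q, k, v, target_dtype=None):
    # If the inputs were silently upcast to float32 (e.g. by LayerNorms kept in fp32),
    # cast them down to target_dtype before calling the kernel, mirroring what the
    # target_dtype argument of _flash_attention_forward is for.
    if target_dtype is not None and q.dtype == torch.float32:
        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
    return fake_flash_kernel(q, k, v)


q = k = v = torch.randn(1, 2, 4, 8)  # float32, e.g. coming out of an fp32 LayerNorm
out = attention_with_target_dtype(q, k, v, target_dtype=torch.bfloat16)
print(out.dtype)  # torch.bfloat16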
@@ -280,12 +280,12 @@ class BltSelfAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = config.rope_theta
         self.layer_idx = layer_idx
+        self.is_causal = True

         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self.is_causal = True

     def forward(
         self,
@@ -680,6 +680,7 @@ class KosmosTextAttention(nn.Module):
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
+        self.is_causal = True

         if (self.head_dim * num_heads) != self.embed_dim:
             raise ValueError(
@@ -519,6 +519,7 @@ class MllamaTextSelfAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = config.rope_theta
         self.layer_idx = layer_idx
+        self.is_causal = True

         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
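The blt, kosmos2 and mllama hunks all add self.is_causal = True in __init__. The shared flash_attention_forward reads is_causal from kwargs first (is_causal = kwargs.pop("is_causal", None) in the first file) and, when no kwarg is given, presumably falls back to the module attribute, which is why every attention class routed through it needs the attribute to exist. A minimal sketch of that precedence, with hypothetical names (pick_is_causal, DemoAttention); the fallback step is an assumption based on the failure this commit fixes:

from torch import nn


class DemoAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.is_causal = True  # what the Kosmos/Mllama/Blt classes now define


def pick_is_causal(module, **kwargs):
    # kwargs override first, module attribute as the fallback.
    is_causal = kwargs.pop("is_causal", None)
    if is_causal is None:
        is_causal = module.is_causal
    return is_causal


module = DemoAttention()
print(pick_is_causal(module))                   # True, from the module attribute
print(pick_is_causal(module, is_causal=False))  # False, the kwarg wins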
@@ -52,6 +52,7 @@ if is_torch_flex_attn_available():


 if is_flash_attn_available():
+    from ...integrations.flash_attention import get_target_dtype
     from ...modeling_flash_attention_utils import _flash_attention_forward


@@ -495,6 +496,8 @@ class StableLmFlashAttention2(StableLmAttention):

         dropout_rate = self.attention_dropout.p if self.training else 0.0

+        target_dtype = get_target_dtype(query_states, self)
+
         attn_output = _flash_attention_forward(
             query_states,
             key_states,

@@ -505,6 +508,7 @@ class StableLmFlashAttention2(StableLmAttention):
             dropout=dropout_rate,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             is_causal=self.is_causal,
+            target_dtype=target_dtype,
         )

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
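For completeness, the quantized-model branch of the helper can be checked directly. The sketch below assumes a transformers install that already includes this commit and that transformers.integrations.flash_attention imports cleanly in your environment; FakeQuantizedAttention and its throwaway config object are illustrative only.

import torch
from torch import nn

# Assumes a transformers version containing the get_target_dtype helper added above.
from transformers.integrations.flash_attention import get_target_dtype


class FakeQuantizedAttention(nn.Module):
    """Illustrative stand-in for an attention module loaded from a quantized checkpoint."""

    def __init__(self):
        super().__init__()
        # Hypothetical minimal config; quantized models record their original dtype like this.
        self.config = type("Config", (), {"_pre_quantization_dtype": torch.bfloat16})()
        self.proj = nn.Linear(8, 8)  # present but irrelevant here: the config branch wins


query = torch.randn(2, 4, 8)  # float32, e.g. after an fp32 LayerNorm
print(get_target_dtype(query, FakeQuantizedAttention()))  # torch.bfloat16, not the Linear weight dtype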