Compare commits

...

4 Commits

SHA1 Message Date
2b9f5bcd26 add output_attentions 2024-10-22 10:37:36 +02:00
2c26de910c add test 2024-10-21 12:47:28 +02:00
2a22431fc2 add _supports_flex_attention 2024-10-21 12:11:15 +02:00
aed788fe15 initial flex-attention for gpt2 2024-10-21 11:50:10 +02:00
3 changed files with 185 additions and 2 deletions

src/transformers/modeling_utils.py (View File)

@@ -1570,12 +1570,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2", "flex"]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attention:
message += (
', `"attn_implementation=flex"` (implementation using torch.nn.attention.flex_attention)'
)
raise ValueError(message + ".")
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
@@ -1611,6 +1615,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
elif requested_attn_implementation == "flex":
config._attn_implementation = "flex"
else:
config._attn_implementation = "eager"

src/transformers/models/gpt2/modeling_gpt2.py (View File)

@@ -19,6 +19,7 @@ import math
import os
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, Tuple, Union
import torch
@@ -26,6 +27,7 @@ import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from ...activations import ACT2FN
from ...generation import GenerationMixin
@@ -562,6 +564,140 @@ class GPT2SdpaAttention(GPT2Attention):
return attn_output, present, None
@lru_cache
def create_block_mask_cached(mask_mod, B, H, M, N, device="cuda"):
block_mask = create_block_mask(mask_mod, B, H, M, N, device=device)
return block_mask
# Defined at module level (rather than inside `forward`) so the function object is
# stable across calls and `create_block_mask_cached`'s `lru_cache` can actually hit.
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
class GPT2FlexAttention(GPT2Attention):
"""
GPT2 attention module using torch.nn.attention.flex_attention. This module inherits from
`GPT2Attention` as the weights of the module stay untouched. The only changes are on the forward pass
to adapt to the flex-attention API.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if head_mask is not None:
logger.warning_once(
"`GPT2FlexAttention` is used but `torch.nn.attention.flex_attention` does not support "
"`head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
bsz, q_len, _ = hidden_states.size()
# Initial attention projections
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
)
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
# Optional kv caching
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = None
if use_cache is True:
present = (key, value)
# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# Build a causal block mask only when we are genuinely doing causal self-attention:
# no explicit attention mask, more than one query token, and not cross-attention.
# (`causal_mask` lives at module level so the cached block mask can be reused across calls.)
is_causal = attention_mask is None and q_len > 1 and not is_cross_attention
block_mask = (
create_block_mask_cached(
causal_mask,
bsz,
self.num_heads,
q_len,
key.size(-2),
device=query.device,
)
if is_causal
else None
)
attn_output = flex_attention(
query,
key,
value,
block_mask=block_mask,
return_lse=output_attentions,
)
if output_attentions:
attn_output, attn_weights = attn_output
# With `return_lse=True`, `flex_attention` returns the logsumexp of the attention
# scores, shaped (batch, num_heads, q_len), not the full attention probabilities.
attn_weights_reshaped = attn_weights
# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.embed_dim)
# Final projection
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights_reshaped,)
return outputs
class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@@ -579,7 +715,12 @@ class GPT2MLP(nn.Module):
return hidden_states
GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}
GPT2_ATTENTION_CLASSES = {
"eager": GPT2Attention,
"flash_attention_2": GPT2FlashAttention2,
"sdpa": GPT2SdpaAttention,
"flex": GPT2FlexAttention,
}
class GPT2Block(nn.Module):
@@ -676,6 +817,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attention = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
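
The heart of the new module is the pairing of a cached causal `BlockMask` with `flex_attention`. A standalone sketch of that pattern, independent of the diff (shapes and helper names are illustrative; requires PyTorch 2.5+ and a CUDA device):

import torch
from functools import lru_cache
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # A position may attend only to itself and earlier positions.
    return q_idx >= kv_idx

@lru_cache
def cached_causal_block_mask(B, H, Q, KV, device):
    # BlockMask construction is comparatively expensive, so reuse it for
    # repeated shapes, mirroring create_block_mask_cached in the diff.
    return create_block_mask(causal_mask, B, H, Q, KV, device=device)

B, H, S, D = 2, 12, 128, 64  # GPT-2-small-like head layout, chosen for illustration
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

out = flex_attention(q, k, v, block_mask=cached_causal_block_mask(B, H, S, S, "cuda"))
print(out.shape)  # torch.Size([2, 12, 128, 64])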

tests/models/gpt2/test_modeling_gpt2.py (View File)

@@ -939,3 +939,38 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_native, output_fa_2)
self.assertListEqual(output_native, expected_output)
@require_torch_gpu
@slow
def test_flex_attention_generate_padding_left(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = GPT2LMHeadModel.from_pretrained(
"gpt2", device_map={"": 0}, attn_implementation="flex", torch_dtype=torch.float16
)
output_flex = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_flex = tokenizer.batch_decode(output_flex)
expected_output = [
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata",
"Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry",
]
self.assertListEqual(output_native, output_flex)
self.assertListEqual(output_native, expected_output)
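
Beyond the greedy-decoding comparison above, a quick logit-level parity check against the eager path makes a useful smoke test. The snippet below is a hypothetical addition, not part of the diff, and assumes the same branch and a CUDA device:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("Flex attention should match eager", return_tensors="pt").to(0)

eager = GPT2LMHeadModel.from_pretrained(
    "gpt2", attn_implementation="eager", torch_dtype=torch.float16
).to(0)
flex = GPT2LMHeadModel.from_pretrained(
    "gpt2", attn_implementation="flex", torch_dtype=torch.float16
).to(0)

with torch.no_grad():
    diff = (eager(**inputs).logits - flex(**inputs).logits).abs().max()

# fp16 kernels are not bit-identical; a loose tolerance is expected.
assert diff < 5e-2, f"flex and eager diverge: {diff}"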