Compare commits


1 Commit

SHA1         Message  Date
8ba84b49ee   v1       2023-10-23 13:15:33 +02:00


@@ -42,9 +42,14 @@ from .configuration_llama import LlamaConfig
 if is_flash_attn_2_available():
+    import flash_attn
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

+    _flash_supports_flash_decoding = hasattr(flash_attn, "flash_attn_with_kvcache")
+    if _flash_supports_flash_decoding:
+        from flash_attn import flash_attn_with_kvcache


 logger = logging.get_logger(__name__)
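
As a side note, not part of the patch: the guarded import above works because flash_attn_with_kvcache only exists in newer flash-attn builds, so its presence is probed with hasattr rather than pinned to a version. A minimal standalone probe in the same spirit (it assumes the installed package exposes __version__, as current flash-attn releases do):

    # Probe the installed flash-attn for the flash-decoding kernel, mirroring
    # the hasattr() check introduced in the hunk above.
    import flash_attn

    if hasattr(flash_attn, "flash_attn_with_kvcache"):
        from flash_attn import flash_attn_with_kvcache  # noqa: F401
        print("flash_attn_with_kvcache available in flash-attn", flash_attn.__version__)
    else:
        print("flash-attn", flash_attn.__version__, "has no flash_attn_with_kvcache")
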
@@ -451,10 +456,18 @@ class LlamaFlashAttention2(LlamaAttention):
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

+        # Used only for `flash_attn_with_kvcache`
+        key_cache = None
+        value_cache = None
         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            if not _flash_supports_flash_decoding and not self.training:
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            else:
+                key_cache = past_key_value[0].transpose(1, 2)
+                value_cache = past_key_value[1].transpose(1, 2)

         past_key_value = (key_states, value_states) if use_cache else None
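
For context on the transpose(1, 2) in the new else branch: this version of transformers keeps each cached key/value tensor in (batch, num_heads, kv_len, head_dim) layout, while flash_attn_with_kvcache expects its k_cache/v_cache as (batch, kv_len, num_heads, head_dim). A toy shape check of that conversion, with arbitrarily chosen sizes:

    import torch

    bsz, num_heads, kv_len, head_dim = 2, 32, 128, 128

    # Tuple-cache layout used here: (batch, heads, kv_len, head_dim).
    past_key = torch.randn(bsz, num_heads, kv_len, head_dim, dtype=torch.float16)

    # flash_attn_with_kvcache wants (batch, kv_len, heads, head_dim), hence transpose(1, 2).
    key_cache = past_key.transpose(1, 2)
    assert key_cache.shape == (bsz, kv_len, num_heads, head_dim)
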
@@ -492,7 +505,7 @@ class LlamaFlashAttention2(LlamaAttention):
             value_states = value_states.to(target_dtype)

         attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
+            query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate, key_cache=key_cache, value_cache=value_cache
         )

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -504,7 +517,7 @@ class LlamaFlashAttention2(LlamaAttention):
         return attn_output, attn_weights, past_key_value

     def _flash_attention_forward(
-        self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
+        self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None, key_cache=None, value_cache=None,
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -550,9 +563,18 @@ class LlamaFlashAttention2(LlamaAttention):
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
-            )
+            if key_cache is None:
+                attn_output = flash_attn_func(
+                    query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
+                )
+            else:
+                attn_output = flash_attn_with_kvcache(
+                    query_states,
+                    key_cache,
+                    value_cache,
+                    softmax_scale=softmax_scale,
+                    causal=True
+                )

         return attn_output
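
For reference, a standalone sketch of a decode-step call in the same spirit as the new key_cache branch. It assumes a CUDA GPU with flash-attn installed; the cache_seqlens argument (not used by this commit) follows the upstream flash-attn interface, and the exact keyword set may vary across flash-attn versions, so treat this as a sketch rather than the canonical API:

    import torch
    from flash_attn import flash_attn_with_kvcache

    bsz, nheads, head_dim = 1, 32, 128
    max_cache_len, filled_len = 2048, 100

    # Pre-allocated KV cache in (batch, kv_len, heads, head_dim) layout.
    k_cache = torch.zeros(bsz, max_cache_len, nheads, head_dim, dtype=torch.float16, device="cuda")
    v_cache = torch.zeros_like(k_cache)

    # One new query token per sequence during autoregressive decoding.
    q = torch.randn(bsz, 1, nheads, head_dim, dtype=torch.float16, device="cuda")

    # Number of cache slots actually filled for each sequence in the batch.
    cache_seqlens = torch.full((bsz,), filled_len, dtype=torch.int32, device="cuda")

    out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)
    print(out.shape)  # (bsz, 1, nheads, head_dim)
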