Compare commits


3 Commits

SHA1        Message  Date
421bf8611a  fix 2    2025-02-06 15:58:29 +01:00
6ba13f577b  update   2025-02-05 13:44:09 +01:00
82ca6920c6  update   2025-02-05 13:40:59 +01:00
3 changed files with 25 additions and 8 deletions

integrations/sdpa_attention.py

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional, Tuple, TypedDict
 
 import torch
 
@@ -62,3 +62,17 @@ def sdpa_attention_forward(
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
+
+
+class SdpaAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for sdpa Attention.
+
+    Attributes:
+        is_causal (`bool`, *optional*)
+            The value for the argument `is_causal` that is passed to `torch.nn.functional.scaled_dot_product_attention`.
+            An error is thrown if both attention_mask and is_causal are set. If `None`, it is inferred in
+            `sdpa_attention_forward`.
+    """
+
+    is_causal: Optional[bool]

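For context, a minimal self-contained sketch of the pattern this file adopts: a `total=False` TypedDict describing optional keyword arguments, consumed through `**kwargs` annotated with `Unpack` (PEP 692). The toy function below is illustrative only and is not part of the diff.

from typing import Optional

from typing_extensions import TypedDict, Unpack  # `Unpack` is in `typing` only from Python 3.11+


class SdpaAttentionKwargs(TypedDict, total=False):
    # `total=False` makes every key optional, so callers may simply omit `is_causal`.
    is_causal: Optional[bool]


def toy_attention_forward(**kwargs: Unpack[SdpaAttentionKwargs]) -> Optional[bool]:
    # Read the optional key the way an attention backend would:
    # `None` means "not specified, infer later".
    return kwargs.get("is_causal")


print(toy_attention_forward())                # None: the key was omitted
print(toy_attention_forward(is_causal=True))  # True: recognized, correctly typed key
# A static type checker would reject toy_attention_forward(dropout=0.1),
# since `dropout` is not declared in SdpaAttentionKwargs.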
modeling_utils.py

@@ -30,7 +30,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypedDict, TypeVar, Union
 from zipfile import is_zipfile
 
 import torch
@@ -48,8 +48,9 @@ from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
-from .integrations.sdpa_attention import sdpa_attention_forward
+from .integrations.sdpa_attention import sdpa_attention_forward, SdpaAttentionKwargs
 from .loss.loss_utils import LOSS_MAPPING
+from .modeling_flash_attention_utils import FlashAttentionKwargs
 from .pytorch_utils import (  # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
@@ -5702,3 +5703,6 @@ ALL_ATTENTION_FUNCTIONS.update(
         "sdpa": sdpa_attention_forward,
     }
 )
+
+
+AttentionKwargs = Union[FlashAttentionKwargs, SdpaAttentionKwargs]

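As a rough sketch of what the new `AttentionKwargs` alias is for (the names below are toy stand-ins, not the real transformers API): each backend declares its own kwargs TypedDict, forward callables sit in a registry dict like `ALL_ATTENTION_FUNCTIONS`, and the Union alias gives model code one annotation that covers whichever backend is dispatched.

from typing import Callable, Dict, Optional, Tuple, Union

from typing_extensions import TypedDict


class FlashKwargs(TypedDict, total=False):
    # Stand-in for FlashAttentionKwargs, which carries padding-free batching metadata.
    max_length_q: Optional[int]


class SdpaKwargs(TypedDict, total=False):
    is_causal: Optional[bool]


# Single alias for "kwargs accepted by whichever attention backend is configured".
AttentionKwargs = Union[FlashKwargs, SdpaKwargs]

# Registry mapping a backend name to its forward callable, mirroring ALL_ATTENTION_FUNCTIONS.
ATTENTION_FUNCTIONS: Dict[str, Callable[..., Tuple[str, dict]]] = {}


def sdpa_forward(**kwargs) -> Tuple[str, dict]:
    return "sdpa", dict(kwargs)


ATTENTION_FUNCTIONS["sdpa"] = sdpa_forward

# Annotations are not enforced at runtime; dispatch simply forwards **kwargs to the backend.
print(ATTENTION_FUNCTIONS["sdpa"](is_causal=False))  # ('sdpa', {'is_causal': False})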
models/llama/modeling_llama.py

@@ -27,7 +27,6 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -36,7 +35,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionKwargs, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
@@ -262,7 +261,7 @@ class LlamaAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[AttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -326,7 +325,7 @@ class LlamaDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
 
@@ -528,7 +527,7 @@ class LlamaModel(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **flash_attn_kwargs: Unpack[AttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (