from transformers.modeling_utils import AttentionInterface
from transformers.models.llama.modeling_llama import LlamaAttention


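# custom_flex below is a deliberate stand-in; a real entry in transformers'
# attention registry typically has the signature
# (module, query, key, value, attention_mask, **kwargs).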
def custom_flex(x, **kwargs):
    """Dummy function."""
    return x


ALL_ATTENTION_FUNCTIONS = AttentionInterface()
# This indexing statement and associated function should be exported correctly!
ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex
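# Note: AttentionInterface is dict-like, so the item assignment above registers
# the function on this instance; registration meant to be visible globally
# would typically use the AttentionInterface.register classmethod instead.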


class GlobalIndexingAttention(LlamaAttention):
    pass
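

# --- Illustrative usage sketch (not part of the original fixture) ---
# A registered name can be selected with `attn_implementation` at load time.
# Kept commented out so this file's behavior is unchanged; the checkpoint name
# is a placeholder.
#
# from transformers import AutoModelForCausalLM, AttentionInterface
#
# AttentionInterface.register("custom_flex", custom_flex)
# model = AutoModelForCausalLM.from_pretrained(
#     "meta-llama/Llama-3.2-1B",  # placeholder checkpoint
#     attn_implementation="custom_flex",
# )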