Compare commits

...

42 Commits

Author SHA1 Message Date
5060a334de remove layer_idx 2024-12-13 14:07:01 +00:00
caaa5e5508 tgi update 2024-12-12 18:29:26 +00:00
95cb944ee6 be permissive 2024-12-12 11:33:37 +01:00
584b443096 fix unpack imoprt 2024-12-12 10:43:13 +01:00
57eece66af Merge branch 'llama-refactor' of github.com:huggingface/transformers into llama-refactor 2024-12-12 10:36:54 +01:00
9461039d87 nits 2024-12-12 10:36:49 +01:00
f7395cc0cc Merge branch 'main' into llama-refactor 2024-12-12 14:57:58 +05:30
4f36712da1 nit? 2024-12-12 10:24:11 +01:00
2016bc47d0 default init weights 2024-12-12 10:18:38 +01:00
53450ac365 fix 2024-12-12 10:14:20 +01:00
1a5a834f53 fix auto? 2024-12-12 09:53:53 +01:00
3f68c7cf72 9 left! 2024-12-12 09:44:14 +01:00
c224f36d10 fix some tests 2024-12-12 09:39:36 +01:00
725d00caf4 fix some stuff 2024-12-12 09:22:04 +01:00
6028e85990 fixup 2024-12-11 19:50:04 +01:00
7a608da9f8 update 2024-12-11 19:44:29 +01:00
e9d751abaa fix attention_mask 2024-12-11 19:26:39 +01:00
60189825d7 fix! 2024-12-11 19:09:09 +01:00
d9156363bf mm 2024-12-11 18:49:37 +01:00
20c512bc80 clean 2024-12-11 18:07:53 +01:00
7a911efddf update 2024-12-11 17:39:40 +01:00
89d32d6825 fix auto set 2024-12-11 17:30:26 +01:00
3bbae39539 remove tanh 2024-12-11 17:24:56 +01:00
e5d60b4f23 fix 2024-12-11 17:07:20 +01:00
4b9a429a1c style 2024-12-11 16:54:35 +01:00
1ef18f49a9 style 2024-12-11 16:53:02 +01:00
28829d2dd6 there was an issue with tie weight keys 2024-12-11 16:44:04 +01:00
40154815cb revert some stuff 2024-12-11 16:23:18 +01:00
38dd294dd7 fix 2024-12-11 16:10:40 +01:00
1baabd3207 update 2024-12-11 15:13:39 +01:00
dcf7a37ce1 cache concatenates on the wrong axis 2024-12-11 14:38:26 +01:00
f61a5fec41 pass attention 2024-12-11 14:37:12 +01:00
556aa4ec2d updates 2024-12-11 14:35:53 +01:00
341b8ce9fa nits 2024-12-11 14:27:20 +01:00
0418f97553 make auto for causal lm work 2024-12-11 14:18:41 +01:00
39ab8b757b oupts 2024-12-11 13:58:09 +01:00
13a195a7bb _output_embedding and _input_embeding 2024-12-11 13:53:35 +01:00
893ef382c4 nits 2024-12-11 13:47:59 +01:00
4e681b9c72 nits 2024-12-11 13:38:19 +01:00
0384db9c0c more refactoring 2024-12-11 12:39:07 +01:00
f446bd4c00 only change lLlama 2024-12-11 12:20:51 +01:00
f14637a7b5 refactor LlamaAttention 2024-11-28 08:01:54 +01:00
16 changed files with 887 additions and 1253 deletions

View File

@@ -191,6 +191,7 @@ _import_structure = {
"AutoImageProcessor",
"AutoProcessor",
"AutoTokenizer",
"AutoForCausalLM",
],
"models.autoformer": ["AutoformerConfig"],
"models.bark": [
@@ -2611,10 +2612,6 @@ else:
)
_import_structure["models.llama"].extend(
[
"LlamaForCausalLM",
"LlamaForQuestionAnswering",
"LlamaForSequenceClassification",
"LlamaForTokenClassification",
"LlamaModel",
"LlamaPreTrainedModel",
]
@@ -5084,6 +5081,7 @@ if TYPE_CHECKING:
TOKENIZER_MAPPING,
AutoConfig,
AutoFeatureExtractor,
AutoForCausalLM,
AutoImageProcessor,
AutoProcessor,
AutoTokenizer,
@@ -6469,6 +6467,7 @@ if TYPE_CHECKING:
AutoModelForZeroShotObjectDetection,
AutoModelWithLMHead,
)
from .models.auto.modeling_task import AutoForCausalLM
from .models.autoformer import (
AutoformerForPrediction,
AutoformerModel,
@@ -7336,10 +7335,6 @@ if TYPE_CHECKING:
LiltPreTrainedModel,
)
from .models.llama import (
LlamaForCausalLM,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
)

View File

@@ -0,0 +1,41 @@
import torch
from ..modeling_flash_attention_utils import _flash_attention_forward
def flash_attention_forward(
config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
):
if attention_mask is not None:
seq_len = attention_mask.shape[1]
query = query[:, :, :seq_len]
value = value[:, :, :seq_len]
else:
seq_len = query.shape[1]
# Re-transpose them
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
dropout_rate = config.attention_dropout if training else 0.0
input_dtype = query.dtype
if input_dtype == torch.float32:
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
attn_output = _flash_attention_forward(
query,
key,
value,
attention_mask,
seq_len,
config=config,
dropout=dropout_rate,
layer_idx=layer_idx,
**kwargs,
)
return attn_output, None
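Two details worth noting in this wrapper: it returns None for the attention weights because flash-attn never materializes the score matrix (so `output_attentions` requires a different backend), and fp32 inputs are downcast because the flash-attn kernels only run in fp16/bf16. A minimal sketch of that dtype guard in isolation (the helper name is illustrative, not part of this diff):

import torch

def maybe_downcast(t: torch.Tensor, target_dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # mirrors the guard above: flash-attn kernels reject fp32 inputs
    return t.to(target_dtype) if t.dtype == torch.float32 else t

assert maybe_downcast(torch.randn(1, 8, 4, 64)).dtype == torch.float16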

View File

@@ -0,0 +1,28 @@
from ..utils import is_torch_greater_or_equal
if is_torch_greater_or_equal("2.5"):
from torch.nn.attention.flex_attention import flex_attention
def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
def causal_mod(score, b, h, q_idx, kv_idx):
if causal_mask is not None:
score += causal_mask[b][0][q_idx][kv_idx]
return score
attn_output, attention_weights = flex_attention(
query,
key,
value,
score_mod=causal_mod,
enable_gqa=True,
scale=module.scaling,
return_lse=True,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attention_weights
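The additive `score_mod` is the flex-attention counterpart of passing the same 0 / -inf bias as an additive `attn_mask` to SDPA: each raw score gets `causal_mask[b][0][q_idx][kv_idx]` added before the softmax. A toy equivalence check on plain tensors (shapes and names are illustrative, not from this diff):

import torch

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
# 0 on allowed positions, -inf above the diagonal: the usual causal bias layout
bias = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
# additive-mask path (what score_mod emulates element by element)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)
# manual scores + bias + softmax gives the same result
scores = q @ k.transpose(-1, -2) / (8 ** 0.5) + bias
manual = torch.softmax(scores, dim=-1) @ v
torch.testing.assert_close(ref, manual)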

View File

@@ -0,0 +1,39 @@
import torch
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
is_causal = True if causal_mask is None and query.shape[1] > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=module.config.attention_dropout if module.training else 0.0,
is_causal=is_causal,
scale=module.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
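`repeat_kv` only widens the head axis so grouped-query KV tensors line up with the query heads; with the helper above in scope, a quick shape check (numbers are illustrative):

import torch

kv = torch.randn(2, 4, 10, 64)     # (batch, num_key_value_heads, seq, head_dim)
expanded = repeat_kv(kv, n_rep=2)  # each KV head is repeated for its query group
assert expanded.shape == (2, 8, 10, 64)
assert torch.equal(expanded[:, 0], expanded[:, 1])  # heads 0 and 1 share one KV head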

View File

@@ -276,7 +276,7 @@ def _flash_attention_forward(
if not use_top_left_mask:
causal = is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.mistral.modeling_mistral.MistralFlashAttention2.__init__.
causal = is_causal and query_length != 1
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).

View File

@@ -45,6 +45,9 @@ from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
Conv1D,
@@ -1290,6 +1293,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# `config.base_model_tp_plan` during `post_init`.
_tp_plan = None
_output_embedding = None
_input_embedding = None
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
@@ -1487,7 +1493,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
message += (
', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
)
raise ValueError(message + ".")
if config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
pass
else:
raise ValueError(message + ".")
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
requested_attn_implementation = config._attn_implementation_internal
@@ -1525,10 +1534,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(
config,
hard_check_only=False if requested_attn_implementation is None else True,
)
# config = cls._check_and_enable_sdpa(
# config,
# hard_check_only=False if requested_attn_implementation is None else True,
# )
if (
torch.version.hip is not None
@@ -1539,6 +1548,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
elif config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
pass
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
else:
@@ -1801,7 +1812,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if base_model is not self:
return base_model.get_input_embeddings()
else:
raise NotImplementedError
return getattr(self, self._input_embedding)
def set_input_embeddings(self, value: nn.Module):
"""
@@ -1813,8 +1824,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_input_embeddings(value)
elif self._input_embedding is not None:
setattr(self, self._input_embedding, value)
else:
raise NotImplementedError
raise ValueError("No input embedding")
def get_output_embeddings(self) -> nn.Module:
"""
@@ -1823,7 +1836,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Returns:
`nn.Module`: A torch module mapping hidden states to vocabulary.
"""
return None # Overwrite for models with output embeddings
if self._output_embedding is not None:
return getattr(self, self._output_embedding, None)
else:
return None
def set_output_embeddings(self, value: nn.Module):
"""
Set model's output embeddings.
Args:
value (`nn.Module`): A module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_output_embeddings(value)
elif self._output_embedding is not None:
setattr(self, self._output_embedding, value)
else:
raise ValueError("No output embedding")
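These getters and setters resolve embeddings through the new `_input_embedding` / `_output_embedding` class attributes, so a head class only declares where its modules live (resolution is a plain `getattr`, so the names must be direct attributes). A hedged sketch of the contract; the class below is illustrative, not part of this diff:

import torch.nn as nn

class ToyForCausalLM(PreTrainedModel):
    _input_embedding = "embed_tokens"  # read by get/set_input_embeddings above
    _output_embedding = "lm_head"      # read by get/set_output_embeddings above

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

With this, `model.get_output_embeddings()` returns `model.lm_head` without the subclass overriding anything.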
def _init_weights(self, module):
"""
@@ -1832,7 +1863,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
using `from_pretrained`. Any attempt to initialize outside of this function
will be useless as the torch.nn.init function are all replaced with skip.
"""
pass
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _initialize_weights(self, module):
"""
@@ -2509,91 +2548,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.base_model._prune_heads(heads_to_prune)
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
for layer in list(self.modules()):
if isinstance(layer, GradientCheckpointLayer):
layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()
def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
is_gradient_checkpointing_set = False
# Apply it on the top-level module in case the top-level modules supports it
# for example, LongT5Stack inherits from `PreTrainedModel`.
if hasattr(self, "gradient_checkpointing"):
self._gradient_checkpointing_func = gradient_checkpointing_func
self.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
for module in self.modules():
if hasattr(module, "gradient_checkpointing"):
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
if not is_gradient_checkpointing_set:
raise ValueError(
f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
" `gradient_checkpointing` to modules of the model that uses checkpointing."
)
def gradient_checkpointing_disable(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if self.supports_gradient_checkpointing:
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=False)
else:
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
self.apply(partial(self._set_gradient_checkpointing, value=False))
if getattr(self, "_hf_peft_config_loaded", False):
self.disable_input_require_grads()
@property
def gradient_checkpointing(self):
for layer in list(self.modules()):
if isinstance(layer, GradientCheckpointLayer):
return layer.gradient_checkpointing
return False
@property
def is_gradient_checkpointing(self) -> bool:
@@ -5626,3 +5590,139 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
files_content[filename].append(device_map[weight_name])
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {}
ALL_ATTENTION_FUNCTIONS.update(
{
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"sdpa": sdpa_attention_forward,
}
)
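The registry turns attention backends into plain entries that models dispatch on by name. A hedged sketch of the call pattern a refactored attention layer would use (the exact call sites live in the suppressed modeling_llama diff; note that `flash_attention_2` takes the config rather than the module as its first argument):

# inside an attention module's forward, once query/key/value are computed:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(self, query_states, key_states, value_states, attention_mask)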
class GradientCheckpointLayer(torch.nn.Module):
def __init__(self, *args, **kwargs):
self.gradient_checkpointing = False
super().__init__(*args, **kwargs)
def __call__(self, *args, **kwargs):
"""
Adjust the behavior of the inherited class by overriding `__call__`.
Automatically handles gradient checkpointing based on flags in the provided arguments.
"""
# Extract necessary flags and arguments
gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) | getattr(
self, "gradient_checkpointing", False
)
training = self.training
if gradient_checkpointing and training:
# Use gradient checkpointing
return self._apply_gradient_checkpointing(*args, **kwargs)
else:
# Default behavior: call the original `forward` method
return self.forward(*args, **kwargs)
def _apply_gradient_checkpointing(self, *args, **kwargs):
"""
Apply gradient checkpointing using the appropriate function.
By default, uses `torch.utils.checkpoint.checkpoint`.
"""
# Assume `self.forward` is compatible with checkpointing
def wrapped_forward():
return self.forward(*args, **kwargs)
return self._gradient_checkpointing_func(wrapped_forward)
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
self.gradient_checkpointing = True
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()
def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
is_gradient_checkpointing_set = False
# Apply it on the top-level module in case the top-level modules supports it
# for example, LongT5Stack inherits from `PreTrainedModel`.
if hasattr(self, "gradient_checkpointing"):
self._gradient_checkpointing_func = gradient_checkpointing_func
self.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
for module in self.modules():
if hasattr(module, "gradient_checkpointing"):
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
if not is_gradient_checkpointing_set:
raise ValueError(
f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
" `gradient_checkpointing` to modules of the model that uses checkpointing."
)
def gradient_checkpointing_disable(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if self.supports_gradient_checkpointing:
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=False)
else:
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
self.apply(partial(self._set_gradient_checkpointing, value=False))
if getattr(self, "_hf_peft_config_loaded", False):
self.disable_input_require_grads()
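Checkpointing now lives on the layer: `__call__` decides per forward pass, so enabling it is a per-module flag rather than a model-wide rewrite. A minimal sketch of the flow, assuming the surrounding modeling_utils scope (`functools`, `inspect`, `torch.utils.checkpoint.checkpoint`, `logger`) as imported in this file; the toy block is illustrative, not part of the diff:

import torch

class ToyBlock(GradientCheckpointLayer):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

block = ToyBlock().train()
block.gradient_checkpointing_enable()  # sets the flag and installs partial(checkpoint, use_reentrant=True)
out = block(torch.randn(2, 8))         # __call__ reroutes through _apply_gradient_checkpointing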

View File

@@ -122,6 +122,12 @@ else:
"AutoModelForZeroShotObjectDetection",
"AutoModelForImageTextToText",
]
_import_structure["modeling_task"] = [
"AutoForCausalLM",
"AutoForSequenceClassification",
"AutoForQuestionAnswering",
"AutoForTokenClassification",
]
try:
if not is_tf_available():
@@ -311,6 +317,12 @@ if TYPE_CHECKING:
AutoModelForZeroShotObjectDetection,
AutoModelWithLMHead,
)
from .modeling_task import (
AutoForCausalLM,
AutoForQuestionAnswering,
AutoForSequenceClassification,
AutoForTokenClassification,
)
try:
if not is_tf_available():

View File

@@ -438,7 +438,9 @@ class _BaseAutoModelClass:
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs)
else:
model_class = cls._model_mapping["auto"]
return model_class._from_config(config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -564,6 +566,11 @@ class _BaseAutoModelClass:
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
else:
model_class = cls._model_mapping[PretrainedConfig]
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."

View File

@@ -39,6 +39,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("aria_text", "AriaTextConfig"),
("audio-spectrogram-transformer", "ASTConfig"),
("autoformer", "AutoformerConfig"),
("auto", "PretrainedConfig"),
("bark", "BarkConfig"),
("bart", "BartConfig"),
("beit", "BeitConfig"),
@@ -335,6 +336,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("aria_text", "AriaText"),
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
("autoformer", "Autoformer"),
("auto", "Auto"),
("bark", "Bark"),
("bart", "BART"),
("barthez", "BARThez"),
@@ -912,6 +914,8 @@ class AutoConfig:
This class cannot be instantiated directly using `__init__()` (throws an error).
"""
model_type = "auto"
def __init__(self):
raise EnvironmentError(
"AutoConfig is designed to be instantiated "

View File

@@ -468,6 +468,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("auto", "AutoForCausalLM"),
("aria_text", "AriaTextForCausalLM"),
("bart", "BartForCausalLM"),
("bert", "BertLMHeadModel"),
@@ -479,7 +480,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForCausalLM"),
("code_llama", "LlamaForCausalLM"),
("codegen", "CodeGenForCausalLM"),
("cohere", "CohereForCausalLM"),
("cpmant", "CpmAntForCausalLM"),
@@ -506,7 +506,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("granitemoe", "GraniteMoeForCausalLM"),
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
("mamba", "MambaForCausalLM"),
("mamba2", "Mamba2ForCausalLM"),
("marian", "MarianForCausalLM"),
@@ -931,6 +930,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("albert", "AlbertForSequenceClassification"),
("auto", "AutoForSequenceClassification"),
("bart", "BartForSequenceClassification"),
("bert", "BertForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
@@ -939,7 +939,6 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("bloom", "BloomForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("code_llama", "LlamaForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
("ctrl", "CTRLForSequenceClassification"),
("data2vec-text", "Data2VecTextForSequenceClassification"),
@@ -971,7 +970,6 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("lilt", "LiltForSequenceClassification"),
("llama", "LlamaForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("luke", "LukeForSequenceClassification"),
("markuplm", "MarkupLMForSequenceClassification"),
@@ -1028,6 +1026,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("albert", "AlbertForQuestionAnswering"),
("auto", "AutoForQuestionAnswering"),
("bart", "BartForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
("big_bird", "BigBirdForQuestionAnswering"),
@@ -1056,7 +1055,6 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("lilt", "LiltForQuestionAnswering"),
("llama", "LlamaForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("luke", "LukeForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
@@ -1125,6 +1123,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("albert", "AlbertForTokenClassification"),
("auto", "AutoForTokenClassification"),
("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("biogpt", "BioGptForTokenClassification"),
@@ -1158,7 +1157,6 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
("lilt", "LiltForTokenClassification"),
("llama", "LlamaForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("luke", "LukeForTokenClassification"),
("markuplm", "MarkupLMForTokenClassification"),

View File

@@ -0,0 +1,298 @@
from typing import List, Optional, Tuple, Union
from ...processing_utils import Unpack
import torch
import torch.nn as nn
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import (
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils.generic import KwargsForCausalLM, validate_config_kwargs
from ..auto import AutoConfig, AutoModel
class AutoForCausalLM(PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_output_embedding = "lm_head"
_no_split_modules = []
_supports_cache_class = True
config_class = AutoConfig
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, config):
super().__init__(config)
self.model = AutoModel.from_config(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module):
self.model.set_input_embeddings(value)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
return_dict=True,
**kwargs,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
output = CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return output if return_dict else output.to_tuple()
class AutoForSequenceClassification(PreTrainedModel):
config_class = AutoConfig
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = AutoModel.from_config(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
return_dict=True,
**kwargs
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
class AutoForQuestionAnswering(PreTrainedModel):
base_model_prefix = "transformer"
config_class = AutoConfig
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
def __init__(self, config):
super().__init__(config)
self.transformer = AutoModel.from_config(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.transformer.embed_tokens
def set_input_embeddings(self, value):
self.transformer.embed_tokens = value
@validate_config_kwargs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
return_dict=True,
**kwargs
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
loss = None
if start_positions is not None and end_positions is not None:
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return QuestionAnsweringModelOutput(
loss=loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class AutoForTokenClassification(PreTrainedModel):
config_class = AutoConfig
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = AutoModel.from_config(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict=None,
**kwargs,
) -> Union[Tuple, TokenClassifierOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
return_dict=True,
**kwargs,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.config)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
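These four heads are model-agnostic: each builds its backbone with `AutoModel.from_config(config)` and adds only the task head, which is what lets the Llama-specific classes be deleted elsewhere in this diff. A hedged usage sketch, mirroring the updated tests below (the checkpoint is an example):

import torch
from transformers.models.auto.modeling_task import AutoForCausalLM

model = AutoForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
input_ids = torch.tensor([[1, 306, 4658, 278]])
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)  # CausalLMOutputWithPast, like the per-model classes it replaces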

View File

@@ -50,12 +50,8 @@ except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_llama"] = [
"LlamaForCausalLM",
"LlamaModel",
"LlamaPreTrainedModel",
"LlamaForSequenceClassification",
"LlamaForQuestionAnswering",
"LlamaForTokenClassification",
]
try:
@@ -93,10 +89,6 @@ if TYPE_CHECKING:
pass
else:
from .modeling_llama import (
LlamaForCausalLM,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
)

File diff suppressed because it is too large

View File

@@ -26,9 +26,10 @@ from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict, Union, Dict
import numpy as np
import torch
from packaging import version
from .import_utils import (
@@ -871,6 +872,54 @@ class LossKwargs(TypedDict, total=False):
num_items_in_batch: Optional[int]
class KwargsForCausalLM(LossKwargs, total=False):
input_ids: torch.LongTensor
attention_mask: Optional[torch.Tensor]
position_ids: Optional[torch.LongTensor]
past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]]
inputs_embeds: Optional[torch.FloatTensor]
labels: Optional[torch.LongTensor]
use_cache: Optional[bool]
output_attentions: Optional[bool]
output_hidden_states: Optional[bool]
return_dict: Optional[bool]
cache_position: Optional[torch.LongTensor]
num_logits_to_keep: int
def validate_config_kwargs(func):
"""
A decorator to validate and initialize kwargs based on a config object.
"""
@wraps(func)
def wrapper(*args, **kwargs):
self = args[0]
# Default values from the config
default_kwargs = {
"output_attentions": self.config.output_attentions,
"output_hidden_states": self.config.output_hidden_states,
"use_cache": self.config.use_cache,
"return_dict": self.config.use_return_dict,
}
# Merge provided kwargs with defaults
validated_kwargs = {**default_kwargs, **kwargs}
# # Validate kwargs against TypedDict
# for key in validated_kwargs:
# if key not in KwargsForCausalLM.__annotations__:
# raise ValueError(f"Invalid keyword argument: {key}")
if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]:
validated_kwargs["use_cache"] = False
# Pass the validated kwargs to the function
return func(*args, **validated_kwargs)
return wrapper
def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
"""Checks whether a config dict is a timm config dict."""
return "pretrained_cfg" in config_dict

View File

@@ -23,6 +23,12 @@ from parameterized import parameterized
from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.models.auto import (
AutoForCausalLM,
AutoForQuestionAnswering,
AutoForSequenceClassification,
AutoForTokenClassification,
)
from transformers.testing_utils import (
cleanup,
require_flash_attn,
@@ -44,14 +50,9 @@ if is_torch_available():
import torch
from transformers import (
LlamaForCausalLM,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaTokenizer,
)
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
class LlamaModelTester:
@@ -197,7 +198,7 @@ class LlamaModelTester:
encoder_hidden_states,
encoder_attention_mask,
):
model = LlamaForCausalLM(config=config)
model = AutoForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@@ -217,7 +218,7 @@
):
config.is_decoder = True
config.add_cross_attention = True
model = LlamaForCausalLM(config=config)
model = AutoForCausalLM(config=config)
model.to(torch_device)
model.eval()
@@ -285,23 +286,23 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
all_model_classes = (
(
LlamaModel,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForQuestionAnswering,
LlamaForTokenClassification,
AutoForCausalLM,
AutoForSequenceClassification,
AutoForQuestionAnswering,
AutoForTokenClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
all_generative_model_classes = (AutoForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": LlamaModel,
"text-classification": LlamaForSequenceClassification,
"text-generation": LlamaForCausalLM,
"zero-shot": LlamaForSequenceClassification,
"question-answering": LlamaForQuestionAnswering,
"token-classification": LlamaForTokenClassification,
"text-classification": AutoForSequenceClassification,
"text-generation": AutoForCausalLM,
"zero-shot": AutoForSequenceClassification,
"question-answering": AutoForQuestionAnswering,
"token-classification": AutoForTokenClassification,
}
if is_torch_available()
else {}
@@ -315,7 +316,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
model_split_percents = [0.5, 0.7, 0.8]
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
_torch_compile_train_cls = AutoForCausalLM if is_torch_available() else None
def setUp(self):
self.model_tester = LlamaModelTester(self)
@@ -340,7 +341,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = LlamaForSequenceClassification(config)
model = AutoForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@@ -353,7 +354,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = LlamaForSequenceClassification(config)
model = AutoForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@@ -368,7 +369,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
sequence_labels = ids_tensor(
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
).to(torch.float)
model = LlamaForSequenceClassification(config)
model = AutoForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@@ -380,7 +381,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = LlamaForTokenClassification(config=config)
model = AutoForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
@@ -424,71 +425,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
original_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
# Sanity check Yarn RoPE scaling
# Scaling should be over the entire input
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
def test_model_loading_old_rope_configs(self):
def _reinitialize_config(base_config, new_kwargs):
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
@@ -499,17 +435,17 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# from untouched config -> ✅
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
original_model = LlamaForCausalLM(base_config).to(torch_device)
original_model = AutoForCausalLM(base_config).to(torch_device)
original_model(**model_inputs)
# from a config with the expected rope configuration -> ✅
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model = AutoForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model = AutoForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
@@ -518,13 +454,13 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
self.assertTrue(config.rope_scaling["type"] == "linear")
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
original_model = LlamaForCausalLM(config).to(torch_device)
original_model = AutoForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model = AutoForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("factor field", logs.output[0])
@@ -534,7 +470,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
config = _reinitialize_config(
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
)
original_model = LlamaForCausalLM(config).to(torch_device)
original_model = AutoForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("Unrecognized keys", logs.output[0])
@@ -557,7 +493,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
model = model_class(config)
model.save_pretrained(tmp_dir)
new_model = LlamaForCausalLM.from_pretrained(
new_model = AutoForCausalLM.from_pretrained(
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
).to("cuda")
@@ -608,7 +544,7 @@ class LlamaIntegrationTest(unittest.TestCase):
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model = LlamaForCausalLM.from_pretrained(
model = AutoForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
)
input_text = ["Tell me about the french revolution."]
@@ -623,7 +559,7 @@ class LlamaIntegrationTest(unittest.TestCase):
def test_model_7b_logits_bf16(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = LlamaForCausalLM.from_pretrained(
model = AutoForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
)
@@ -667,7 +603,7 @@ class LlamaIntegrationTest(unittest.TestCase):
def test_model_7b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = LlamaForCausalLM.from_pretrained(
model = AutoForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)
@@ -717,7 +653,7 @@ class LlamaIntegrationTest(unittest.TestCase):
)
prompt = "Simply put, the theory of relativity states that "
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = LlamaForCausalLM.from_pretrained(
model = AutoForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
)
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -754,7 +690,7 @@ class LlamaIntegrationTest(unittest.TestCase):
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained(
model = AutoForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
@@ -819,7 +755,7 @@ class LlamaIntegrationTest(unittest.TestCase):
cache_implementation = "static"
attn_implementation = "sdpa"
batch_size = 1
model = LlamaForCausalLM.from_pretrained(
model = AutoForCausalLM.from_pretrained(
llama_model_ckp,
device_map=device,
torch_dtype=dtype,
@@ -859,7 +795,7 @@ class Mask4DTestHard(unittest.TestCase):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
self.model_dtype = torch.float32
self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
self.model = AutoForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
def get_test_data(self):
template = "my favorite {}"

View File

@@ -943,78 +943,79 @@ class ModelTesterMixin:
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
with self.subTest(model_class.__name__):
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = 5
if self.is_encoder_decoder:
correct_outlen = 5
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits
if model_class.__name__ in [
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
*get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
]:
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits
if model_class.__name__ in [
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
*get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
]:
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
self.assertEqual(out_len, correct_outlen)
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
@@ -1022,30 +1023,31 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
with self.subTest(model_class.__name__):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
@slow
def test_torchscript_simple(self):