Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: v4.49.0...llama-refa (42 commits)
Commits (SHA1):

5060a334de
caaa5e5508
95cb944ee6
584b443096
57eece66af
9461039d87
f7395cc0cc
4f36712da1
2016bc47d0
53450ac365
1a5a834f53
3f68c7cf72
c224f36d10
725d00caf4
6028e85990
7a608da9f8
e9d751abaa
60189825d7
d9156363bf
20c512bc80
7a911efddf
89d32d6825
3bbae39539
e5d60b4f23
4b9a429a1c
1ef18f49a9
28829d2dd6
40154815cb
38dd294dd7
1baabd3207
dcf7a37ce1
f61a5fec41
556aa4ec2d
341b8ce9fa
0418f97553
39ab8b757b
13a195a7bb
893ef382c4
4e681b9c72
0384db9c0c
f446bd4c00
f14637a7b5
@@ -191,6 +191,7 @@ _import_structure = {
        "AutoImageProcessor",
        "AutoProcessor",
        "AutoTokenizer",
        "AutoForCausalLM",
    ],
    "models.autoformer": ["AutoformerConfig"],
    "models.bark": [
@@ -2611,10 +2612,6 @@ else:
    )
    _import_structure["models.llama"].extend(
        [
            "LlamaForCausalLM",
            "LlamaForQuestionAnswering",
            "LlamaForSequenceClassification",
            "LlamaForTokenClassification",
            "LlamaModel",
            "LlamaPreTrainedModel",
        ]
@@ -5084,6 +5081,7 @@ if TYPE_CHECKING:
        TOKENIZER_MAPPING,
        AutoConfig,
        AutoFeatureExtractor,
        AutoForCausalLM,
        AutoImageProcessor,
        AutoProcessor,
        AutoTokenizer,
@@ -6469,6 +6467,7 @@ if TYPE_CHECKING:
        AutoModelForZeroShotObjectDetection,
        AutoModelWithLMHead,
    )
    from .models.auto.modeling_task import AutoForCausalLM
    from .models.autoformer import (
        AutoformerForPrediction,
        AutoformerModel,
@@ -7336,10 +7335,6 @@ if TYPE_CHECKING:
        LiltPreTrainedModel,
    )
    from .models.llama import (
        LlamaForCausalLM,
        LlamaForQuestionAnswering,
        LlamaForSequenceClassification,
        LlamaForTokenClassification,
        LlamaModel,
        LlamaPreTrainedModel,
    )
src/transformers/integrations/flash_attention.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import torch

from ..modeling_flash_attention_utils import _flash_attention_forward


def flash_attention_forward(
    config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
):
    if attention_mask is not None:
        seq_len = attention_mask.shape[1]
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]
    else:
        seq_len = query.shape[1]

    # Re-transpose them
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    dropout_rate = config.attention_dropout if training else 0.0

    input_dtype = query.dtype
    if input_dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        seq_len,
        config=config,
        dropout=dropout_rate,
        layer_idx=layer_idx,
        **kwargs,
    )

    return attn_output, None
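Note: `flash_attention_forward` above receives `query`/`key`/`value` as `(batch, num_heads, seq_len, head_dim)` and re-transposes them to the `(batch, seq_len, num_heads, head_dim)` layout expected by `_flash_attention_forward`, downcasting fp32 inputs to `target_dtype` on the way. A torch-only sketch of that pre-processing (shapes are illustrative and the flash-attention kernel itself is not called):

import torch

# Illustrative shapes: batch=2, heads=8, seq=16, head_dim=64
query = torch.randn(2, 8, 16, 64, dtype=torch.float32)

# Same re-layout and downcast performed by flash_attention_forward
query_for_fa = query.transpose(1, 2)                # -> (batch, seq, heads, head_dim)
if query_for_fa.dtype == torch.float32:
    query_for_fa = query_for_fa.to(torch.float16)   # target_dtype default

print(query_for_fa.shape, query_for_fa.dtype)       # torch.Size([2, 16, 8, 64]) torch.float16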
src/transformers/integrations/flex_attention.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from ..utils import is_torch_greater_or_equal


if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention


def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
        if causal_mask is not None:
            score += causal_mask[b][0][q_idx][kv_idx]
        return score

    attn_output, attention_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        enable_gqa=True,
        scale=module.scaling,
        return_lse=True,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attention_weights
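Note: the `score_mod` hook above folds the additive causal mask directly into the attention scores before the softmax. A self-contained illustration of the same idea with plain tensors (no `flex_attention`, which requires torch >= 2.5; sizes are illustrative):

import torch

scores = torch.randn(1, 1, 4, 4)  # (batch, heads, q_len, kv_len)
causal_bias = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)

# What a score_mod such as causal_mod effectively does: score += mask bias
probs = (scores + causal_bias).softmax(dim=-1)

# Future positions receive exactly zero attention weight
assert torch.allclose(probs.triu(1), torch.zeros_like(probs))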
src/transformers/integrations/sdpa_attention.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
    key = repeat_kv(key, module.num_key_value_groups)
    value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    is_causal = True if causal_mask is None and query.shape[1] > 1 else False
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=module.config.attention_dropout if module.training else 0.0,
        is_causal=is_causal,
        scale=module.scaling,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
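Note: `repeat_kv` is what lets grouped-query attention reuse a small set of key/value heads across all query heads before the SDPA call. A quick shape check (sizes are arbitrary):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the helper above: (batch, kv_heads, seq, dim) -> (batch, kv_heads * n_rep, seq, dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

key = torch.randn(2, 2, 16, 64)   # 2 key/value heads
key = repeat_kv(key, n_rep=4)     # replicate each kv head for the 4 query heads in its group
print(key.shape)                  # torch.Size([2, 8, 16, 64])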
@@ -276,7 +276,7 @@ def _flash_attention_forward(
    if not use_top_left_mask:
        causal = is_causal
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.mistral.modeling_mistral.MistralFlashAttention2.__init__.
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
@@ -45,6 +45,9 @@ from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import (  # noqa: F401
    Conv1D,
@@ -1290,6 +1293,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    # `config.base_model_tp_plan` during `post_init`.
    _tp_plan = None

    _output_embedding = None
    _input_embedding = None

    @property
    def dummy_inputs(self) -> Dict[str, torch.Tensor]:
        """
@@ -1487,7 +1493,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                message += (
                    ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
                )
                raise ValueError(message + ".")
            if config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
                pass
            else:
                raise ValueError(message + ".")

        # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
        requested_attn_implementation = config._attn_implementation_internal
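With the registry check above in place, any key present in `ALL_ATTENTION_FUNCTIONS` is accepted as a valid `attn_implementation`. A usage sketch with the existing `from_pretrained` argument (the checkpoint name is illustrative):

from transformers import AutoModelForCausalLM

# "sdpa", "flash_attention_2" and "flex_attention" map to the forwards registered in the new integrations modules
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",   # illustrative checkpoint
    attn_implementation="sdpa",
)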
@@ -1525,10 +1534,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
        elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
            # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
            config = cls._check_and_enable_sdpa(
                config,
                hard_check_only=False if requested_attn_implementation is None else True,
            )
            # config = cls._check_and_enable_sdpa(
            #     config,
            #     hard_check_only=False if requested_attn_implementation is None else True,
            # )

            if (
                torch.version.hip is not None
@@ -1539,6 +1548,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                    "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
                )
                torch.backends.cuda.enable_flash_sdp(False)
        elif config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
            pass
        elif isinstance(requested_attn_implementation, dict):
            config._attn_implementation = None
        else:
@@ -1801,7 +1812,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError
            return getattr(self, self._input_embedding)

    def set_input_embeddings(self, value: nn.Module):
        """
@@ -1813,8 +1824,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        elif self._input_embedding is not None:
            setattr(self, self._input_embedding, value)
        else:
            raise NotImplementedError
            raise ValueError("No input embedding")

    def get_output_embeddings(self) -> nn.Module:
        """
@@ -1823,7 +1836,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        Returns:
            `nn.Module`: A torch module mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings
        if self._output_embedding is not None:
            return getattr(self, self._output_embedding, None)
        else:
            return None

    def set_output_embeddings(self, value: nn.Module):
        """
        Set model's input embeddings.

        Args:
            value (`nn.Module`): A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_output_embeddings(value)
        elif self._output_embedding is not None:
            setattr(self, self._output_embedding, value)
        else:
            raise ValueError()

    def _init_weights(self, module):
        """
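The embedding hunks above replace per-model `get_/set_input/output_embeddings` overrides with class-level `_input_embedding` / `_output_embedding` attribute names that `PreTrainedModel` resolves via `getattr`/`setattr`. A toy sketch of the pattern outside of transformers (the class and attribute names are illustrative):

import torch.nn as nn

class TinyLM(nn.Module):
    _output_embedding = "lm_head"  # name of the submodule holding the output projection

    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(16, 100, bias=False)

    def get_output_embeddings(self):
        # Mirrors the new PreTrainedModel behaviour: resolve the module by attribute name
        return getattr(self, self._output_embedding, None) if self._output_embedding is not None else None

    def set_output_embeddings(self, value):
        setattr(self, self._output_embedding, value)

model = TinyLM()
model.set_output_embeddings(nn.Linear(16, 100, bias=False))
print(model.get_output_embeddings())  # Linear(in_features=16, out_features=100, bias=False)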
@@ -1832,7 +1863,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        using `from_pretrained`. Any attempt to initialize outside of this function
        will be useless as the torch.nn.init function are all replaced with skip.
        """
        pass
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _initialize_weights(self, module):
        """
@@ -2509,91 +2548,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        self.base_model._prune_heads(heads_to_prune)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """
        Activates gradient checkpointing for the current model.
        for layer in list(self.modules()):
            if isinstance(layer, GradientCheckpointLayer):
                layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".

        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

        Args:
            gradient_checkpointing_kwargs (dict, *optional*):
                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
        """
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {"use_reentrant": True}

        gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

        # For old GC format (transformers < 4.35.0) for models that live on the Hub
        # we will fall back to the overwritten `_set_gradient_checkpointing` method
        _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

        if not _is_using_old_format:
            self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
        else:
            self.apply(partial(self._set_gradient_checkpointing, value=True))
            logger.warning(
                "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
                "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
            )

        if getattr(self, "_hf_peft_config_loaded", False):
            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
            # the gradients to make sure the gradient flows.
            self.enable_input_require_grads()

    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
        is_gradient_checkpointing_set = False

        # Apply it on the top-level module in case the top-level modules supports it
        # for example, LongT5Stack inherits from `PreTrainedModel`.
        if hasattr(self, "gradient_checkpointing"):
            self._gradient_checkpointing_func = gradient_checkpointing_func
            self.gradient_checkpointing = enable
            is_gradient_checkpointing_set = True

        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                module._gradient_checkpointing_func = gradient_checkpointing_func
                module.gradient_checkpointing = enable
                is_gradient_checkpointing_set = True

        if not is_gradient_checkpointing_set:
            raise ValueError(
                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
                " `gradient_checkpointing` to modules of the model that uses checkpointing."
            )

    def gradient_checkpointing_disable(self):
        """
        Deactivates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        if self.supports_gradient_checkpointing:
            # For old GC format (transformers < 4.35.0) for models that live on the Hub
            # we will fall back to the overwritten `_set_gradient_checkpointing` methid
            _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
            if not _is_using_old_format:
                self._set_gradient_checkpointing(enable=False)
            else:
                logger.warning(
                    "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
                    "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
                )
                self.apply(partial(self._set_gradient_checkpointing, value=False))

        if getattr(self, "_hf_peft_config_loaded", False):
            self.disable_input_require_grads()
    @property
    def gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
        for layer in list(self.modules()):
            if isinstance(layer, GradientCheckpointLayer):
                return layer.gradient_checkpointing
        return False

    @property
    def is_gradient_checkpointing(self) -> bool:
@@ -5626,3 +5590,139 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
        files_content[filename].append(device_map[weight_name])

    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]


ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {}

ALL_ATTENTION_FUNCTIONS.update(
    {
        "flash_attention_2": flash_attention_forward,
        "flex_attention": flex_attention_forward,
        "sdpa": sdpa_attention_forward,
    }
)


class GradientCheckpointLayer(torch.nn.Module):

    def __init__(self, *args, **kwargs):
        self.gradient_checkpointing = False
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """
        Adjust the behavior of the inherited class by overriding `__call__`.

        Automatically handles gradient checkpointing based on flags in the provided arguments.
        """
        # Extract necessary flags and arguments
        gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) | getattr(
            self, "gradient_checkpointing", False
        )
        training = self.training

        if gradient_checkpointing and training:
            # Use gradient checkpointing
            return self._apply_gradient_checkpointing(*args, **kwargs)
        else:
            # Default behavior: call the original `forward` method
            return self.forward(*args, **kwargs)

    def _apply_gradient_checkpointing(self, *args, **kwargs):
        """
        Apply gradient checkpointing using the appropriate function.

        By default, uses `torch.utils.checkpoint.checkpoint`.
        """

        # Assume `self.forward` is compatible with checkpointing
        def wrapped_forward():
            return self.forward(*args, **kwargs)

        return self._gradient_checkpointing_func(wrapped_forward)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """
        Activates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".

        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

        Args:
            gradient_checkpointing_kwargs (dict, *optional*):
                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
        """
        self.gradient_checkpointing = True

        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {"use_reentrant": True}

        gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

        # For old GC format (transformers < 4.35.0) for models that live on the Hub
        # we will fall back to the overwritten `_set_gradient_checkpointing` method
        _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

        if not _is_using_old_format:
            self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
        else:
            self.apply(partial(self._set_gradient_checkpointing, value=True))
            logger.warning(
                "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
                "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
            )

        if getattr(self, "_hf_peft_config_loaded", False):
            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
            # the gradients to make sure the gradient flows.
            self.enable_input_require_grads()

    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
        is_gradient_checkpointing_set = False

        # Apply it on the top-level module in case the top-level modules supports it
        # for example, LongT5Stack inherits from `PreTrainedModel`.
        if hasattr(self, "gradient_checkpointing"):
            self._gradient_checkpointing_func = gradient_checkpointing_func
            self.gradient_checkpointing = enable
            is_gradient_checkpointing_set = True

        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                module._gradient_checkpointing_func = gradient_checkpointing_func
                module.gradient_checkpointing = enable
                is_gradient_checkpointing_set = True

        if not is_gradient_checkpointing_set:
            raise ValueError(
                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
                " `gradient_checkpointing` to modules of the model that uses checkpointing."
            )

    def gradient_checkpointing_disable(self):
        """
        Deactivates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        if self.supports_gradient_checkpointing:
            # For old GC format (transformers < 4.35.0) for models that live on the Hub
            # we will fall back to the overwritten `_set_gradient_checkpointing` methid
            _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
            if not _is_using_old_format:
                self._set_gradient_checkpointing(enable=False)
            else:
                logger.warning(
                    "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
                    "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
                )
                self.apply(partial(self._set_gradient_checkpointing, value=False))

        if getattr(self, "_hf_peft_config_loaded", False):
            self.disable_input_require_grads()
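The hunk above turns attention into a name-keyed registry (`ALL_ATTENTION_FUNCTIONS`) and moves gradient checkpointing into the per-layer `GradientCheckpointLayer` base class. A minimal sketch of the registry pattern with a stand-in dictionary and a toy attention module (none of these names come from the diff itself):

import torch

# Stand-in registry with the same shape as ALL_ATTENTION_FUNCTIONS: name -> callable
ATTENTION_REGISTRY = {}

def eager_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
    # Plain softmax attention, following the (attn_output, attn_weights) return convention used above
    scores = torch.matmul(query, key.transpose(-2, -1)) * module.scaling
    if attention_mask is not None:
        scores = scores + attention_mask
    weights = scores.softmax(dim=-1)
    attn_output = torch.matmul(weights, value).transpose(1, 2).contiguous()
    return attn_output, weights

ATTENTION_REGISTRY["eager"] = eager_attention_forward

class ToyAttention(torch.nn.Module):
    def __init__(self, attn_implementation="eager"):
        super().__init__()
        self.scaling = 64 ** -0.5
        self.attn_implementation = attn_implementation

    def forward(self, query, key, value, attention_mask=None):
        # Dispatch by name, mirroring the config._attn_implementation lookups in modeling_utils
        attention_fn = ATTENTION_REGISTRY[self.attn_implementation]
        return attention_fn(self, query, key, value, attention_mask)

q = k = v = torch.randn(1, 8, 16, 64)
out, _ = ToyAttention()(q, k, v)
print(out.shape)  # torch.Size([1, 16, 8, 64])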
@@ -122,6 +122,12 @@ else:
        "AutoModelForZeroShotObjectDetection",
        "AutoModelForImageTextToText",
    ]
    _import_structure["modeling_task"] = [
        "AutoForCausalLM",
        "AutoForSequenceClassification",
        "AutoForQuestionAnswering",
        "AutoForTokenClassification",
    ]

try:
    if not is_tf_available():
@@ -311,6 +317,12 @@ if TYPE_CHECKING:
        AutoModelForZeroShotObjectDetection,
        AutoModelWithLMHead,
    )
    from .modeling_task import (
        AutoForCausalLM,
        AutoForQuestionAnswering,
        AutoForSequenceClassification,
        AutoForTokenClassification,
    )

try:
    if not is_tf_available():
@@ -438,7 +438,9 @@ class _BaseAutoModelClass:
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class._from_config(config, **kwargs)

        else:
            model_class = cls._model_mapping["auto"]
            return model_class._from_config(config, **kwargs)
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -564,6 +566,11 @@ class _BaseAutoModelClass:
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        else:
            model_class = cls._model_mapping[PretrainedConfig]
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
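The factory changes above add a last-resort branch: when the config class has no dedicated entry in the mapping, dispatch falls back to the entry registered under the "auto" key (or under `PretrainedConfig` in `from_pretrained`) instead of raising immediately. A small dictionary sketch of that lookup order (the names are illustrative):

class UnknownConfig:  # stands in for a config class with no dedicated model class
    pass

MODEL_MAPPING = {
    "LlamaConfig": "LlamaModel",
    "auto": "AutoForCausalLM",  # generic fallback added on this branch
}

def resolve_model_class(config) -> str:
    # Mirrors the new lookup order: exact config class first, then the "auto" fallback
    name = type(config).__name__
    if name in MODEL_MAPPING:
        return MODEL_MAPPING[name]
    return MODEL_MAPPING["auto"]

print(resolve_model_class(UnknownConfig()))  # AutoForCausalLM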
@@ -39,6 +39,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("aria_text", "AriaTextConfig"),
        ("audio-spectrogram-transformer", "ASTConfig"),
        ("autoformer", "AutoformerConfig"),
        ("auto", "PretrainedConfig"),
        ("bark", "BarkConfig"),
        ("bart", "BartConfig"),
        ("beit", "BeitConfig"),
@@ -335,6 +336,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("aria_text", "AriaText"),
        ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
        ("autoformer", "Autoformer"),
        ("auto", "Auto"),
        ("bark", "Bark"),
        ("bart", "BART"),
        ("barthez", "BARThez"),
@@ -912,6 +914,8 @@ class AutoConfig:
    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    model_type = "auto"

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
@@ -468,6 +468,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("auto", "AutoForCausalLM"),
        ("aria_text", "AriaTextForCausalLM"),
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
@@ -479,7 +480,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("code_llama", "LlamaForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("cohere", "CohereForCausalLM"),
        ("cpmant", "CpmAntForCausalLM"),
@@ -506,7 +506,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("granitemoe", "GraniteMoeForCausalLM"),
        ("jamba", "JambaForCausalLM"),
        ("jetmoe", "JetMoeForCausalLM"),
        ("llama", "LlamaForCausalLM"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("marian", "MarianForCausalLM"),
@@ -931,6 +930,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "AlbertForSequenceClassification"),
        ("auto", "AutoForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
@@ -939,7 +939,6 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("code_llama", "LlamaForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
@@ -971,7 +970,6 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("lilt", "LiltForSequenceClassification"),
        ("llama", "LlamaForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("markuplm", "MarkupLMForSequenceClassification"),
@@ -1028,6 +1026,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "AlbertForQuestionAnswering"),
        ("auto", "AutoForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
@@ -1056,7 +1055,6 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("lilt", "LiltForQuestionAnswering"),
        ("llama", "LlamaForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
@@ -1125,6 +1123,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "AlbertForTokenClassification"),
        ("auto", "AutoForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("biogpt", "BioGptForTokenClassification"),
@@ -1158,7 +1157,6 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("lilt", "LiltForTokenClassification"),
        ("llama", "LlamaForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
src/transformers/models/auto/modeling_task.py (new file, 298 lines)
@@ -0,0 +1,298 @@
from typing import List, Optional, Tuple, Union

from ...processing_utils import Unpack

import torch
import torch.nn as nn

from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import (
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils.generic import KwargsForCausalLM, validate_config_kwargs
from ..auto import AutoConfig, AutoModel


class AutoForCausalLM(PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _output_embedding = "lm_head"
    _no_split_modules = []
    _supports_cache_class = True
    config_class = AutoConfig
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.model = AutoModel.from_config(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module):
        self.model.set_input_embeddings(value)


    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            return_dict=True,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

        return output if return_dict else output.to_tuple()


class AutoForSequenceClassification(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = AutoModel.from_config(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


class AutoForQuestionAnswering(PreTrainedModel):
    base_model_prefix = "transformer"
    config_class = AutoConfig

    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
    def __init__(self, config):
        super().__init__(config)
        self.transformer = AutoModel.from_config(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @validate_config_kwargs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AutoForTokenClassification(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = AutoModel.from_config(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict=None,
        **kwargs,
    ) -> Union[Tuple, TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
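`modeling_task.py` introduces config-agnostic task heads that wrap whatever backbone `AutoModel.from_config` returns. A hedged usage sketch, assuming this branch is installed (the class only exists on `llama-refa`, and the tiny config values are illustrative):

import torch
from transformers import LlamaConfig
from transformers.models.auto.modeling_task import AutoForCausalLM

config = LlamaConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
model = AutoForCausalLM(config)  # backbone resolved through AutoModel.from_config

input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids=input_ids, labels=input_ids)
print(out.logits.shape, out.loss)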
@@ -50,12 +50,8 @@ except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_llama"] = [
        "LlamaForCausalLM",
        "LlamaModel",
        "LlamaPreTrainedModel",
        "LlamaForSequenceClassification",
        "LlamaForQuestionAnswering",
        "LlamaForTokenClassification",
    ]

try:
@@ -93,10 +89,6 @@ if TYPE_CHECKING:
        pass
    else:
        from .modeling_llama import (
            LlamaForCausalLM,
            LlamaForQuestionAnswering,
            LlamaForSequenceClassification,
            LlamaForTokenClassification,
            LlamaModel,
            LlamaPreTrainedModel,
        )
(File diff suppressed because it is too large.)
@@ -26,9 +26,10 @@ from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict, Union, Dict

import numpy as np
import torch
from packaging import version

from .import_utils import (
@@ -871,6 +872,54 @@ class LossKwargs(TypedDict, total=False):
    num_items_in_batch: Optional[int]


class KwargsForCausalLM(LossKwargs):
    input_ids: torch.LongTensor = None
    attention_mask: Optional[torch.Tensor] = None
    position_ids: Optional[torch.LongTensor] = None
    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None
    inputs_embeds: Optional[torch.FloatTensor] = None
    labels: Optional[torch.LongTensor] = None
    use_cache: Optional[bool] = None
    output_attentions: Optional[bool] = None
    output_hidden_states: Optional[bool] = None
    return_dict: Optional[bool] = None
    cache_position: Optional[torch.LongTensor] = None
    num_logits_to_keep: int = 0


def validate_config_kwargs(func):
    """
    A decorator to validate and initialize kwargs based on a config object.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        self = args[0]
        # Default values from the config
        default_kwargs = {
            "output_attentions": self.config.output_attentions,
            "output_hidden_states": self.config.output_hidden_states,
            "use_cache": self.config.use_cache,
            "return_dict": self.config.use_return_dict,
        }

        # Merge provided kwargs with defaults
        validated_kwargs = {**default_kwargs, **kwargs}

        # # Validate kwargs against TypedDict
        # for key in validated_kwargs:
        #     if key not in KwargsForCausalLM.__annotations__:
        #         raise ValueError(f"Invalid keyword argument: {key}")

        if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]:
            validated_kwargs["use_cache"] = False

        # Pass the validated kwargs to the function
        return func(*args, **validated_kwargs)

    return wrapper


def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
    """Checks whether a config dict is a timm config dict."""
    return "pretrained_cfg" in config_dict
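`validate_config_kwargs` merges config defaults into the call-site kwargs before the wrapped `forward` runs, with explicit arguments taking precedence. A toy sketch of the same merge on plain objects (the dummy classes are illustrative):

from functools import wraps

def validate_config_kwargs(func):
    # Same idea as above: fill kwargs from self.config, then let explicit kwargs win
    @wraps(func)
    def wrapper(self, **kwargs):
        defaults = {"output_attentions": self.config.output_attentions, "use_cache": self.config.use_cache}
        return func(self, **{**defaults, **kwargs})
    return wrapper

class DummyConfig:
    output_attentions = False
    use_cache = True

class DummyModel:
    config = DummyConfig()

    @validate_config_kwargs
    def forward(self, **kwargs):
        return kwargs

print(DummyModel().forward(output_attentions=True))
# {'output_attentions': True, 'use_cache': True}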
@ -23,6 +23,12 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
from transformers.models.auto import (
|
||||
AutoForCausalLM,
|
||||
AutoForQuestionAnswering,
|
||||
AutoForSequenceClassification,
|
||||
AutoForTokenClassification,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_flash_attn,
|
||||
@ -44,14 +50,9 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
LlamaForCausalLM,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
||||
|
||||
|
||||
class LlamaModelTester:
|
||||
@ -197,7 +198,7 @@ class LlamaModelTester:
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = LlamaForCausalLM(config=config)
|
||||
model = AutoForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
@ -217,7 +218,7 @@ class LlamaModelTester:
|
||||
):
|
||||
config.is_decoder = True
|
||||
config.add_cross_attention = True
|
||||
model = LlamaForCausalLM(config=config)
|
||||
model = AutoForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@ -285,23 +286,23 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
all_model_classes = (
|
||||
(
|
||||
LlamaModel,
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForTokenClassification,
|
||||
AutoForCausalLM,
|
||||
AutoForSequenceClassification,
|
||||
AutoForQuestionAnswering,
|
||||
AutoForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (AutoForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": LlamaModel,
|
||||
"text-classification": LlamaForSequenceClassification,
|
||||
"text-generation": LlamaForCausalLM,
|
||||
"zero-shot": LlamaForSequenceClassification,
|
||||
"question-answering": LlamaForQuestionAnswering,
|
||||
"token-classification": LlamaForTokenClassification,
|
||||
"text-classification": AutoForSequenceClassification,
|
||||
"text-generation": AutoForCausalLM,
|
||||
"zero-shot": AutoForSequenceClassification,
|
||||
"question-answering": AutoForQuestionAnswering,
|
||||
"token-classification": AutoForTokenClassification,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
@ -315,7 +316,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
# used in `test_torch_compile_for_training`
|
||||
_torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
|
||||
_torch_compile_train_cls = AutoForCausalLM if is_torch_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LlamaModelTester(self)
|
||||
@ -340,7 +341,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = LlamaForSequenceClassification(config)
|
||||
model = AutoForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
@ -353,7 +354,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = LlamaForSequenceClassification(config)
|
||||
model = AutoForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
@ -368,7 +369,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
sequence_labels = ids_tensor(
|
||||
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
||||
).to(torch.float)
|
||||
model = LlamaForSequenceClassification(config)
|
||||
model = AutoForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
@ -380,7 +381,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = LlamaForTokenClassification(config=config)
|
||||
model = AutoForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
@ -424,71 +425,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
def test_model_rope_scaling(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
scaling_factor = 10
|
||||
short_input_length = 10
|
||||
long_input_length = int(config.max_position_embeddings * 1.5)
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
||||
linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
# with scaling_factor (or that `inv_freq` decreases)
|
||||
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
|
||||
ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_cos_long, original_cos_long)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
||||
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
||||
|
||||
# Sanity check Yarn RoPE scaling
|
||||
# Scaling should be over the entire input
|
||||
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
|
||||
yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
|
||||
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
|
||||
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_cos_short, original_cos_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_cos_long, original_cos_long)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
||||
|
||||
def test_model_loading_old_rope_configs(self):
|
||||
def _reinitialize_config(base_config, new_kwargs):
|
||||
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
|
||||
@ -499,17 +435,17 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi

        # from untouched config -> ✅
        base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
-        original_model = LlamaForCausalLM(base_config).to(torch_device)
+        original_model = AutoForCausalLM(base_config).to(torch_device)
        original_model(**model_inputs)

        # from a config with the expected rope configuration -> ✅
        config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
-        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model = AutoForCausalLM(config).to(torch_device)
        original_model(**model_inputs)

        # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
        config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
-        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model = AutoForCausalLM(config).to(torch_device)
        original_model(**model_inputs)

        # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
@ -518,13 +454,13 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
        )
        self.assertTrue(config.rope_scaling["type"] == "linear")
        self.assertTrue(config.rope_scaling["rope_type"] == "linear")
-        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model = AutoForCausalLM(config).to(torch_device)
        original_model(**model_inputs)

        # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
            config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
-            original_model = LlamaForCausalLM(config).to(torch_device)
+            original_model = AutoForCausalLM(config).to(torch_device)
            original_model(**model_inputs)
        self.assertEqual(len(logs.output), 1)
        self.assertIn("factor field", logs.output[0])
@ -534,7 +470,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
            config = _reinitialize_config(
                base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
            )
-            original_model = LlamaForCausalLM(config).to(torch_device)
+            original_model = AutoForCausalLM(config).to(torch_device)
            original_model(**model_inputs)
        self.assertEqual(len(logs.output), 1)
        self.assertIn("Unrecognized keys", logs.output[0])
@ -557,7 +493,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
                model = model_class(config)
                model.save_pretrained(tmp_dir)

-                new_model = LlamaForCausalLM.from_pretrained(
+                new_model = AutoForCausalLM.from_pretrained(
                    tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
                ).to("cuda")

@ -608,7 +544,7 @@ class LlamaIntegrationTest(unittest.TestCase):
        )

        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
        )
        input_text = ["Tell me about the french revolution."]
@ -623,7 +559,7 @@ class LlamaIntegrationTest(unittest.TestCase):
    def test_model_7b_logits_bf16(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]

-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
        )

@ -667,7 +603,7 @@ class LlamaIntegrationTest(unittest.TestCase):
    def test_model_7b_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]

-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
        )

@ -717,7 +653,7 @@ class LlamaIntegrationTest(unittest.TestCase):
        )
        prompt = "Simply put, the theory of relativity states that "
        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
        )
        model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@ -754,7 +690,7 @@ class LlamaIntegrationTest(unittest.TestCase):
            "My favorite all time favorite condiment is ketchup.",
        ]
        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
        )
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
@ -819,7 +755,7 @@ class LlamaIntegrationTest(unittest.TestCase):
        cache_implementation = "static"
        attn_implementation = "sdpa"
        batch_size = 1
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
            llama_model_ckp,
            device_map=device,
            torch_dtype=dtype,
@ -859,7 +795,7 @@ class Mask4DTestHard(unittest.TestCase):
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.model_dtype = torch.float32
        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
-        self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+        self.model = AutoForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

    def get_test_data(self):
        template = "my favorite {}"
@ -943,78 +943,79 @@ class ModelTesterMixin:
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        for model_class in self.all_model_classes:
-            inputs_dict["output_attentions"] = True
-            inputs_dict["output_hidden_states"] = False
-            config.return_dict = True
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-            with torch.no_grad():
-                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
-            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            with self.subTest(model_class.__name__):
+                inputs_dict["output_attentions"] = True
+                inputs_dict["output_hidden_states"] = False
+                config.return_dict = True
+                model = model_class(config)
+                model.to(torch_device)
+                model.eval()
+                with torch.no_grad():
+                    outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+                attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+                self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

-            # check that output_attentions also work using config
-            del inputs_dict["output_attentions"]
-            config.output_attentions = True
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-            with torch.no_grad():
-                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
-            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+                # check that output_attentions also work using config
+                del inputs_dict["output_attentions"]
+                config.output_attentions = True
+                model = model_class(config)
+                model.to(torch_device)
+                model.eval()
+                with torch.no_grad():
+                    outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+                attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+                self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
-            out_len = len(outputs)
+                if chunk_length is not None:
+                    self.assertListEqual(
+                        list(attentions[0].shape[-4:]),
+                        [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
+                    )
+                else:
+                    self.assertListEqual(
+                        list(attentions[0].shape[-3:]),
+                        [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                    )
+                out_len = len(outputs)

-            if self.is_encoder_decoder:
-                correct_outlen = 5
+                if self.is_encoder_decoder:
+                    correct_outlen = 5

-                # loss is at first position
-                if "labels" in inputs_dict:
-                    correct_outlen += 1  # loss is added to beginning
-                # Question Answering model returns start_logits and end_logits
-                if model_class.__name__ in [
-                    *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
-                    *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
-                ]:
-                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output
-                if "past_key_values" in outputs:
-                    correct_outlen += 1  # past_key_values have been returned
+                    # loss is at first position
+                    if "labels" in inputs_dict:
+                        correct_outlen += 1  # loss is added to beginning
+                    # Question Answering model returns start_logits and end_logits
+                    if model_class.__name__ in [
+                        *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
+                        *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
+                    ]:
+                        correct_outlen += 1  # start_logits and end_logits instead of only 1 output
+                    if "past_key_values" in outputs:
+                        correct_outlen += 1  # past_key_values have been returned

-                self.assertEqual(out_len, correct_outlen)
+                    self.assertEqual(out_len, correct_outlen)

-                # decoder attentions
-                decoder_attentions = outputs.decoder_attentions
-                self.assertIsInstance(decoder_attentions, (list, tuple))
-                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
-                self.assertListEqual(
-                    list(decoder_attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
-                )
+                    # decoder attentions
+                    decoder_attentions = outputs.decoder_attentions
+                    self.assertIsInstance(decoder_attentions, (list, tuple))
+                    self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+                    self.assertListEqual(
+                        list(decoder_attentions[0].shape[-3:]),
+                        [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+                    )

-                # cross attentions
-                cross_attentions = outputs.cross_attentions
-                self.assertIsInstance(cross_attentions, (list, tuple))
-                self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
-                self.assertListEqual(
-                    list(cross_attentions[0].shape[-3:]),
-                    [
-                        self.model_tester.num_attention_heads,
-                        decoder_seq_length,
-                        encoder_key_length,
-                    ],
-                )
+                    # cross attentions
+                    cross_attentions = outputs.cross_attentions
+                    self.assertIsInstance(cross_attentions, (list, tuple))
+                    self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+                    self.assertListEqual(
+                        list(cross_attentions[0].shape[-3:]),
+                        [
+                            self.model_tester.num_attention_heads,
+                            decoder_seq_length,
+                            encoder_key_length,
+                        ],
+                    )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
@ -1022,30 +1023,31 @@ class ModelTesterMixin:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
-            with torch.no_grad():
-                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            with self.subTest(model_class.__name__):
+                with torch.no_grad():
+                    outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
-            else:
-                added_hidden_states = 1
-            self.assertEqual(out_len + added_hidden_states, len(outputs))
+                if hasattr(self.model_tester, "num_hidden_states_types"):
+                    added_hidden_states = self.model_tester.num_hidden_states_types
+                elif self.is_encoder_decoder:
+                    added_hidden_states = 2
+                else:
+                    added_hidden_states = 1
+                self.assertEqual(out_len + added_hidden_states, len(outputs))

-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+                self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

-            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
-                )
-            else:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-                )
+                self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+                if chunk_length is not None:
+                    self.assertListEqual(
+                        list(self_attentions[0].shape[-4:]),
+                        [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
+                    )
+                else:
+                    self.assertListEqual(
+                        list(self_attentions[0].shape[-3:]),
+                        [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                    )

    @slow
    def test_torchscript_simple(self):