Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: v4.53.0...llama-refa (42 commits)
SHA1 |
---|
5060a334de |
caaa5e5508 |
95cb944ee6 |
584b443096 |
57eece66af |
9461039d87 |
f7395cc0cc |
4f36712da1 |
2016bc47d0 |
53450ac365 |
1a5a834f53 |
3f68c7cf72 |
c224f36d10 |
725d00caf4 |
6028e85990 |
7a608da9f8 |
e9d751abaa |
60189825d7 |
d9156363bf |
20c512bc80 |
7a911efddf |
89d32d6825 |
3bbae39539 |
e5d60b4f23 |
4b9a429a1c |
1ef18f49a9 |
28829d2dd6 |
40154815cb |
38dd294dd7 |
1baabd3207 |
dcf7a37ce1 |
f61a5fec41 |
556aa4ec2d |
341b8ce9fa |
0418f97553 |
39ab8b757b |
13a195a7bb |
893ef382c4 |
4e681b9c72 |
0384db9c0c |
f446bd4c00 |
f14637a7b5 |
src/transformers/__init__.py

@@ -191,6 +191,7 @@ _import_structure = {
         "AutoImageProcessor",
         "AutoProcessor",
         "AutoTokenizer",
+        "AutoForCausalLM",
     ],
     "models.autoformer": ["AutoformerConfig"],
     "models.bark": [
@@ -2611,10 +2612,6 @@ else:
     )
     _import_structure["models.llama"].extend(
         [
-            "LlamaForCausalLM",
-            "LlamaForQuestionAnswering",
-            "LlamaForSequenceClassification",
-            "LlamaForTokenClassification",
             "LlamaModel",
             "LlamaPreTrainedModel",
         ]
@@ -5084,6 +5081,7 @@ if TYPE_CHECKING:
         TOKENIZER_MAPPING,
         AutoConfig,
         AutoFeatureExtractor,
+        AutoForCausalLM,
         AutoImageProcessor,
         AutoProcessor,
         AutoTokenizer,
@@ -6469,6 +6467,7 @@ if TYPE_CHECKING:
         AutoModelForZeroShotObjectDetection,
         AutoModelWithLMHead,
     )
+    from .models.auto.modeling_task import AutoForCausalLM
     from .models.autoformer import (
         AutoformerForPrediction,
         AutoformerModel,
@@ -7336,10 +7335,6 @@ if TYPE_CHECKING:
         LiltPreTrainedModel,
     )
     from .models.llama import (
-        LlamaForCausalLM,
-        LlamaForQuestionAnswering,
-        LlamaForSequenceClassification,
-        LlamaForTokenClassification,
         LlamaModel,
         LlamaPreTrainedModel,
     )
41  src/transformers/integrations/flash_attention.py  Normal file

@@ -0,0 +1,41 @@
+import torch
+
+from ..modeling_flash_attention_utils import _flash_attention_forward
+
+
+def flash_attention_forward(
+    config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
+):
+    if attention_mask is not None:
+        seq_len = attention_mask.shape[1]
+        query = query[:, :, :seq_len]
+        value = value[:, :, :seq_len]
+    else:
+        seq_len = query.shape[1]
+
+    # Re-transpose them
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    dropout_rate = config.attention_dropout if training else 0.0
+
+    input_dtype = query.dtype
+    if input_dtype == torch.float32:
+        query = query.to(target_dtype)
+        key = key.to(target_dtype)
+        value = value.to(target_dtype)
+
+    attn_output = _flash_attention_forward(
+        query,
+        key,
+        value,
+        attention_mask,
+        seq_len,
+        config=config,
+        dropout=dropout_rate,
+        layer_idx=layer_idx,
+        **kwargs,
+    )
+
+    return attn_output, None
28  src/transformers/integrations/flex_attention.py  Normal file

@@ -0,0 +1,28 @@
+from ..utils import is_torch_greater_or_equal
+
+
+if is_torch_greater_or_equal("2.5"):
+    from torch.nn.attention.flex_attention import flex_attention
+
+
+def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
+    causal_mask = attention_mask
+    if causal_mask is not None:
+        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+
+    def causal_mod(score, b, h, q_idx, kv_idx):
+        if causal_mask is not None:
+            score += causal_mask[b][0][q_idx][kv_idx]
+        return score
+
+    attn_output, attention_weights = flex_attention(
+        query,
+        key,
+        value,
+        score_mod=causal_mod,
+        enable_gqa=True,
+        scale=module.scaling,
+        return_lse=True,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attention_weights
39  src/transformers/integrations/sdpa_attention.py  Normal file

@@ -0,0 +1,39 @@
+import torch
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
+    key = repeat_kv(key, module.num_key_value_groups)
+    value = repeat_kv(value, module.num_key_value_groups)
+
+    causal_mask = attention_mask
+    if attention_mask is not None:
+        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+
+    query = query.contiguous()
+    key = key.contiguous()
+    value = value.contiguous()
+
+    is_causal = True if causal_mask is None and query.shape[1] > 1 else False
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=causal_mask,
+        dropout_p=module.config.attention_dropout if module.training else 0.0,
+        is_causal=is_causal,
+        scale=module.scaling,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, None
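The three new integration modules above deliberately share one call shape: each backend takes the attention module (or its config) plus query, key, value and the attention mask, and returns a pair of attention output and optional weights, which is what lets modeling_utils.py later in this diff register them in a single dispatch table. The sketch below shows that dispatch idea in isolation; the SimpleAttention class, its scaling attribute, and the local ATTENTION_BACKENDS dictionary are illustrative assumptions standing in for ALL_ATTENTION_FUNCTIONS, not code from the diff.

import torch
import torch.nn as nn

# Illustrative stand-in for the ALL_ATTENTION_FUNCTIONS registry added in modeling_utils.py.
ATTENTION_BACKENDS = {}


def sdpa_backend(module, query, key, value, attention_mask=None, **kwargs):
    # Same contract as the new integrations: (batch, heads, seq, head_dim) in,
    # (batch, seq, heads, head_dim) contiguous output plus optional weights (None here).
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, scale=module.scaling
    )
    return out.transpose(1, 2).contiguous(), None


ATTENTION_BACKENDS["sdpa"] = sdpa_backend


class SimpleAttention(nn.Module):
    def __init__(self, hidden_size=64, num_heads=4, backend="sdpa"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim**-0.5
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.backend = backend

    def forward(self, hidden_states, attention_mask=None):
        bsz, seq_len, _ = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq, head_dim) before handing off to the backend.
        q, k, v = (t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        attn_output, _ = ATTENTION_BACKENDS[self.backend](self, q, k, v, attention_mask)
        return attn_output.reshape(bsz, seq_len, -1)


x = torch.randn(2, 5, 64)
print(SimpleAttention()(x).shape)  # torch.Size([2, 5, 64])

Swapping the backend name is then a one-line change, which is the motivation for keying the real registry by strings such as "sdpa", "flex_attention" and "flash_attention_2".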
src/transformers/modeling_flash_attention_utils.py

@@ -276,7 +276,7 @@ def _flash_attention_forward(
     if not use_top_left_mask:
         causal = is_causal
     else:
-        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
+        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.mistral.modeling_mistral.MistralFlashAttention2.__init__.
         causal = is_causal and query_length != 1

     # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
src/transformers/modeling_utils.py

@@ -45,6 +45,9 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from .integrations.flash_attention import flash_attention_forward
+from .integrations.flex_attention import flex_attention_forward
+from .integrations.sdpa_attention import sdpa_attention_forward
 from .loss.loss_utils import LOSS_MAPPING
 from .pytorch_utils import ( # noqa: F401
     Conv1D,
@@ -1290,6 +1293,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     # `config.base_model_tp_plan` during `post_init`.
     _tp_plan = None

+    _output_embedding = None
+    _input_embedding = None
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
@@ -1487,7 +1493,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             message += (
                 ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
             )
-            raise ValueError(message + ".")
+            if config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
+                pass
+            else:
+                raise ValueError(message + ".")

         # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
         requested_attn_implementation = config._attn_implementation_internal
@@ -1525,10 +1534,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
         elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-            config = cls._check_and_enable_sdpa(
-                config,
-                hard_check_only=False if requested_attn_implementation is None else True,
-            )
+            # config = cls._check_and_enable_sdpa(
+            #     config,
+            #     hard_check_only=False if requested_attn_implementation is None else True,
+            # )

             if (
                 torch.version.hip is not None
@@ -1539,6 +1548,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
                 )
                 torch.backends.cuda.enable_flash_sdp(False)
+        elif config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
+            pass
         elif isinstance(requested_attn_implementation, dict):
             config._attn_implementation = None
         else:
@@ -1801,7 +1812,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if base_model is not self:
             return base_model.get_input_embeddings()
         else:
-            raise NotImplementedError
+            return getattr(self, self._input_embedding)

     def set_input_embeddings(self, value: nn.Module):
         """
@@ -1813,8 +1824,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         base_model = getattr(self, self.base_model_prefix, self)
         if base_model is not self:
             base_model.set_input_embeddings(value)
+        elif self._input_embedding is not None:
+            setattr(self, self._input_embedding, value)
         else:
-            raise NotImplementedError
+            raise ValueError("No input embedding")

     def get_output_embeddings(self) -> nn.Module:
         """
@@ -1823,7 +1836,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         Returns:
             `nn.Module`: A torch module mapping hidden states to vocabulary.
         """
-        return None  # Overwrite for models with output embeddings
+        if self._output_embedding is not None:
+            return getattr(self, self._output_embedding, None)
+        else:
+            return None
+
+    def set_output_embeddings(self, value: nn.Module):
+        """
+        Set model's input embeddings.
+
+        Args:
+            value (`nn.Module`): A module mapping vocabulary to hidden states.
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            base_model.set_output_embeddings(value)
+        elif self._output_embedding is not None:
+            setattr(self, self._output_embedding, value)
+        else:
+            raise ValueError()

     def _init_weights(self, module):
         """
@@ -1832,7 +1863,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         using `from_pretrained`. Any attempt to initialize outside of this function
         will be useless as the torch.nn.init function are all replaced with skip.
         """
-        pass
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()

     def _initialize_weights(self, module):
         """
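The embedding-accessor hunks above replace per-model overrides with two class-level names, _input_embedding and _output_embedding, that tell the base class which submodule to fetch or replace, while _init_weights gains a generic normal/zero initialisation driven by config.initializer_range. Below is a small self-contained sketch of the naming convention; TinyForCausalLM and its sizes are hypothetical and only the two attribute names come from the diff.

import torch.nn as nn


class TinyForCausalLM(nn.Module):  # hypothetical model, not from the diff
    # Only these attribute names come from the diff; the base class resolves them with getattr/setattr.
    _input_embedding = "embed_tokens"
    _output_embedding = "lm_head"

    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    # What the new PreTrainedModel accessors reduce to once the names are declared:
    def get_output_embeddings(self):
        return getattr(self, self._output_embedding, None) if self._output_embedding is not None else None

    def set_input_embeddings(self, value):
        setattr(self, self._input_embedding, value)


model = TinyForCausalLM()
model.set_input_embeddings(nn.Embedding(100, 16))
print(type(model.get_output_embeddings()).__name__)  # Linear, resolved via the _output_embedding name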
@@ -2509,91 +2548,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             self.base_model._prune_heads(heads_to_prune)

     def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
-        """
-        Activates gradient checkpointing for the current model.
-
-        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
-        activations".
-
-        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
-        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
-
-        Args:
-            gradient_checkpointing_kwargs (dict, *optional*):
-                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
-        """
-        if not self.supports_gradient_checkpointing:
-            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-
-        if gradient_checkpointing_kwargs is None:
-            gradient_checkpointing_kwargs = {"use_reentrant": True}
-
-        gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
-
-        # For old GC format (transformers < 4.35.0) for models that live on the Hub
-        # we will fall back to the overwritten `_set_gradient_checkpointing` method
-        _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
-
-        if not _is_using_old_format:
-            self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
-        else:
-            self.apply(partial(self._set_gradient_checkpointing, value=True))
-            logger.warning(
-                "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
-                "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
-            )
-
-        if getattr(self, "_hf_peft_config_loaded", False):
-            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
-            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
-            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
-            # the gradients to make sure the gradient flows.
-            self.enable_input_require_grads()
-
-    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
-        is_gradient_checkpointing_set = False
-
-        # Apply it on the top-level module in case the top-level modules supports it
-        # for example, LongT5Stack inherits from `PreTrainedModel`.
-        if hasattr(self, "gradient_checkpointing"):
-            self._gradient_checkpointing_func = gradient_checkpointing_func
-            self.gradient_checkpointing = enable
-            is_gradient_checkpointing_set = True
-
-        for module in self.modules():
-            if hasattr(module, "gradient_checkpointing"):
-                module._gradient_checkpointing_func = gradient_checkpointing_func
-                module.gradient_checkpointing = enable
-                is_gradient_checkpointing_set = True
-
-        if not is_gradient_checkpointing_set:
-            raise ValueError(
-                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
-                " `gradient_checkpointing` to modules of the model that uses checkpointing."
-            )
-
-    def gradient_checkpointing_disable(self):
-        """
-        Deactivates gradient checkpointing for the current model.
-
-        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
-        activations".
-        """
-        if self.supports_gradient_checkpointing:
-            # For old GC format (transformers < 4.35.0) for models that live on the Hub
-            # we will fall back to the overwritten `_set_gradient_checkpointing` methid
-            _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
-            if not _is_using_old_format:
-                self._set_gradient_checkpointing(enable=False)
-            else:
-                logger.warning(
-                    "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
-                    "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
-                )
-                self.apply(partial(self._set_gradient_checkpointing, value=False))
-
-        if getattr(self, "_hf_peft_config_loaded", False):
-            self.disable_input_require_grads()
+        for layer in list(self.modules()):
+            if isinstance(layer, GradientCheckpointLayer):
+                layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
+
+    @property
+    def gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
+        for layer in list(self.modules()):
+            if isinstance(layer, GradientCheckpointLayer):
+                return layer.gradient_checkpointing
+        return False

     @property
     def is_gradient_checkpointing(self) -> bool:
@@ -5626,3 +5590,139 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix)
         files_content[filename].append(device_map[weight_name])

     return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
+
+
+ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {}
+
+ALL_ATTENTION_FUNCTIONS.update(
+    {
+        "flash_attention_2": flash_attention_forward,
+        "flex_attention": flex_attention_forward,
+        "sdpa": sdpa_attention_forward,
+    }
+)
+
+
+class GradientCheckpointLayer(torch.nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        self.gradient_checkpointing = False
+        super().__init__( *args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        """
+        Adjust the behavior of the inherited class by overriding `__call__`.
+
+        Automatically handles gradient checkpointing based on flags in the provided arguments.
+        """
+        # Extract necessary flags and arguments
+        gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) | getattr(
+            self, "gradient_checkpointing", False
+        )
+        training = self.training
+
+        if gradient_checkpointing and training:
+            # Use gradient checkpointing
+            return self._apply_gradient_checkpointing(*args, **kwargs)
+        else:
+            # Default behavior: call the original `forward` method
+            return self.forward(*args, **kwargs)
+
+    def _apply_gradient_checkpointing(self, *args, **kwargs):
+        """
+        Apply gradient checkpointing using the appropriate function.
+
+        By default, uses `torch.utils.checkpoint.checkpoint`.
+        """
+
+        # Assume `self.forward` is compatible with checkpointing
+        def wrapped_forward():
+            return self.forward(*args, **kwargs)
+
+        return self._gradient_checkpointing_func(wrapped_forward)
+
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+        """
+        Activates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+
+        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
+        Args:
+            gradient_checkpointing_kwargs (dict, *optional*):
+                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
+        """
+        self.gradient_checkpointing = True
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+        gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
+
+        # For old GC format (transformers < 4.35.0) for models that live on the Hub
+        # we will fall back to the overwritten `_set_gradient_checkpointing` method
+        _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
+
+        if not _is_using_old_format:
+            self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
+        else:
+            self.apply(partial(self._set_gradient_checkpointing, value=True))
+            logger.warning(
+                "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
+                "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
+            )
+
+        if getattr(self, "_hf_peft_config_loaded", False):
+            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
+            # the gradients to make sure the gradient flows.
+            self.enable_input_require_grads()
+
+    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
+        is_gradient_checkpointing_set = False
+
+        # Apply it on the top-level module in case the top-level modules supports it
+        # for example, LongT5Stack inherits from `PreTrainedModel`.
+        if hasattr(self, "gradient_checkpointing"):
+            self._gradient_checkpointing_func = gradient_checkpointing_func
+            self.gradient_checkpointing = enable
+            is_gradient_checkpointing_set = True
+
+        for module in self.modules():
+            if hasattr(module, "gradient_checkpointing"):
+                module._gradient_checkpointing_func = gradient_checkpointing_func
+                module.gradient_checkpointing = enable
+                is_gradient_checkpointing_set = True
+
+        if not is_gradient_checkpointing_set:
+            raise ValueError(
+                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
+                " `gradient_checkpointing` to modules of the model that uses checkpointing."
+            )
+
+    def gradient_checkpointing_disable(self):
+        """
+        Deactivates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if self.supports_gradient_checkpointing:
+            # For old GC format (transformers < 4.35.0) for models that live on the Hub
+            # we will fall back to the overwritten `_set_gradient_checkpointing` methid
+            _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
+            if not _is_using_old_format:
+                self._set_gradient_checkpointing(enable=False)
+            else:
+                logger.warning(
+                    "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
+                    "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
+                )
+                self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+        if getattr(self, "_hf_peft_config_loaded", False):
+            self.disable_input_require_grads()
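The new GradientCheckpointLayer moves checkpointing from a model-wide switch to the individual layer: each layer keeps its own gradient_checkpointing flag, and __call__ decides per invocation whether to run forward directly or through the checkpoint function, while PreTrainedModel.gradient_checkpointing_enable now just flips that flag on every such layer. A runnable sketch of that behaviour follows; ToyLayer, the default use_reentrant=False checkpoint function, and the trimmed class body are assumptions made so the example stands alone.

import functools

import torch
from torch.utils.checkpoint import checkpoint


class GradientCheckpointLayer(torch.nn.Module):
    # Trimmed restatement of the class added above, for illustration only.
    def __init__(self, *args, **kwargs):
        self.gradient_checkpointing = False
        super().__init__(*args, **kwargs)
        # Assumption: give the checkpoint function a default so the sketch runs standalone.
        self._gradient_checkpointing_func = functools.partial(checkpoint, use_reentrant=False)

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            return self._gradient_checkpointing_func(functools.partial(self.forward, **kwargs), *args)
        return self.forward(*args, **kwargs)


class ToyLayer(GradientCheckpointLayer):  # hypothetical layer, not in the diff
    def __init__(self, dim=8):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return torch.relu(self.linear(hidden_states))


layer = ToyLayer().train()
layer.gradient_checkpointing = True  # what gradient_checkpointing_enable() sets per layer
out = layer(torch.randn(2, 8, requires_grad=True))
out.sum().backward()  # activations are recomputed inside checkpoint during backward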
src/transformers/models/auto/__init__.py

@@ -122,6 +122,12 @@ else:
         "AutoModelForZeroShotObjectDetection",
         "AutoModelForImageTextToText",
     ]
+    _import_structure["modeling_task"] = [
+        "AutoForCausalLM",
+        "AutoForSequenceClassification",
+        "AutoForQuestionAnswering",
+        "AutoForTokenClassification",
+    ]

 try:
     if not is_tf_available():
@@ -311,6 +317,12 @@ if TYPE_CHECKING:
         AutoModelForZeroShotObjectDetection,
         AutoModelWithLMHead,
     )
+    from .modeling_task import (
+        AutoForCausalLM,
+        AutoForQuestionAnswering,
+        AutoForSequenceClassification,
+        AutoForTokenClassification,
+    )

     try:
         if not is_tf_available():
src/transformers/models/auto/auto_factory.py

@@ -438,7 +438,9 @@ class _BaseAutoModelClass:
         elif type(config) in cls._model_mapping.keys():
             model_class = _get_model_class(config, cls._model_mapping)
             return model_class._from_config(config, **kwargs)
+        else:
+            model_class = cls._model_mapping["auto"]
+            return model_class._from_config(config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -564,6 +566,11 @@ class _BaseAutoModelClass:
             return model_class.from_pretrained(
                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
             )
+        else:
+            model_class = cls._model_mapping[PretrainedConfig]
+            return model_class.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
+            )
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
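With the added else branches, a config type that has no dedicated entry in the mapping no longer raises immediately: from_config falls back to the "auto" key and from_pretrained to the PretrainedConfig key, both of which now resolve to the generic task heads from modeling_task.py. Below is a stripped-down sketch of that dispatch order using plain dictionaries; every class name in it is made up for illustration.

class FancyConfig:      # hypothetical config with a dedicated model class
    pass


class UnknownConfig:    # hypothetical config with no dedicated model class
    pass


class FancyModel:
    def __init__(self, config):
        self.config = config


class GenericTaskModel:  # plays the role of the AutoForCausalLM-style fallback
    def __init__(self, config):
        self.config = config


MODEL_MAPPING = {FancyConfig: FancyModel, "auto": GenericTaskModel}


def from_config(config):
    # Mirrors the new logic: exact config type first, then the generic "auto" entry.
    if type(config) in MODEL_MAPPING:
        return MODEL_MAPPING[type(config)](config)
    return MODEL_MAPPING["auto"](config)


assert isinstance(from_config(FancyConfig()), FancyModel)
assert isinstance(from_config(UnknownConfig()), GenericTaskModel)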
src/transformers/models/auto/configuration_auto.py

@@ -39,6 +39,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("aria_text", "AriaTextConfig"),
         ("audio-spectrogram-transformer", "ASTConfig"),
         ("autoformer", "AutoformerConfig"),
+        ("auto", "PretrainedConfig"),
         ("bark", "BarkConfig"),
         ("bart", "BartConfig"),
         ("beit", "BeitConfig"),
@@ -335,6 +336,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("aria_text", "AriaText"),
         ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
         ("autoformer", "Autoformer"),
+        ("auto", "Auto"),
         ("bark", "Bark"),
         ("bart", "BART"),
         ("barthez", "BARThez"),
@@ -912,6 +914,8 @@ class AutoConfig:
     This class cannot be instantiated directly using `__init__()` (throws an error).
     """

+    model_type = "auto"
+
     def __init__(self):
         raise EnvironmentError(
             "AutoConfig is designed to be instantiated "
src/transformers/models/auto/modeling_auto.py

@@ -468,6 +468,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
 MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Causal LM mapping
+        ("auto", "AutoForCausalLM"),
         ("aria_text", "AriaTextForCausalLM"),
         ("bart", "BartForCausalLM"),
         ("bert", "BertLMHeadModel"),
@@ -479,7 +480,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("blenderbot-small", "BlenderbotSmallForCausalLM"),
         ("bloom", "BloomForCausalLM"),
         ("camembert", "CamembertForCausalLM"),
-        ("code_llama", "LlamaForCausalLM"),
         ("codegen", "CodeGenForCausalLM"),
         ("cohere", "CohereForCausalLM"),
         ("cpmant", "CpmAntForCausalLM"),
@@ -506,7 +506,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("granitemoe", "GraniteMoeForCausalLM"),
         ("jamba", "JambaForCausalLM"),
         ("jetmoe", "JetMoeForCausalLM"),
-        ("llama", "LlamaForCausalLM"),
         ("mamba", "MambaForCausalLM"),
         ("mamba2", "Mamba2ForCausalLM"),
         ("marian", "MarianForCausalLM"),
@@ -931,6 +930,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Sequence Classification mapping
         ("albert", "AlbertForSequenceClassification"),
+        ("auto", "AutoForSequenceClassification"),
         ("bart", "BartForSequenceClassification"),
         ("bert", "BertForSequenceClassification"),
         ("big_bird", "BigBirdForSequenceClassification"),
@@ -939,7 +939,6 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("bloom", "BloomForSequenceClassification"),
         ("camembert", "CamembertForSequenceClassification"),
         ("canine", "CanineForSequenceClassification"),
-        ("code_llama", "LlamaForSequenceClassification"),
         ("convbert", "ConvBertForSequenceClassification"),
         ("ctrl", "CTRLForSequenceClassification"),
         ("data2vec-text", "Data2VecTextForSequenceClassification"),
@@ -971,7 +970,6 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
         ("led", "LEDForSequenceClassification"),
         ("lilt", "LiltForSequenceClassification"),
-        ("llama", "LlamaForSequenceClassification"),
         ("longformer", "LongformerForSequenceClassification"),
         ("luke", "LukeForSequenceClassification"),
         ("markuplm", "MarkupLMForSequenceClassification"),
@@ -1028,6 +1026,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
     [
         # Model for Question Answering mapping
         ("albert", "AlbertForQuestionAnswering"),
+        ("auto", "AutoForQuestionAnswering"),
         ("bart", "BartForQuestionAnswering"),
         ("bert", "BertForQuestionAnswering"),
         ("big_bird", "BigBirdForQuestionAnswering"),
@@ -1056,7 +1055,6 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
         ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
         ("led", "LEDForQuestionAnswering"),
         ("lilt", "LiltForQuestionAnswering"),
-        ("llama", "LlamaForQuestionAnswering"),
         ("longformer", "LongformerForQuestionAnswering"),
         ("luke", "LukeForQuestionAnswering"),
         ("lxmert", "LxmertForQuestionAnswering"),
@@ -1125,6 +1123,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Token Classification mapping
         ("albert", "AlbertForTokenClassification"),
+        ("auto", "AutoForTokenClassification"),
         ("bert", "BertForTokenClassification"),
         ("big_bird", "BigBirdForTokenClassification"),
         ("biogpt", "BioGptForTokenClassification"),
@@ -1158,7 +1157,6 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
         ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
         ("lilt", "LiltForTokenClassification"),
-        ("llama", "LlamaForTokenClassification"),
         ("longformer", "LongformerForTokenClassification"),
         ("luke", "LukeForTokenClassification"),
         ("markuplm", "MarkupLMForTokenClassification"),
298  src/transformers/models/auto/modeling_task.py  Normal file

@@ -0,0 +1,298 @@
+from typing import List, Optional, Tuple, Union
+
+from ...processing_utils import Unpack
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import (
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils.generic import KwargsForCausalLM, validate_config_kwargs
+from ..auto import AutoConfig, AutoModel
+
+
+class AutoForCausalLM(PreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _output_embedding = "lm_head"
+    _no_split_modules = []
+    _supports_cache_class = True
+    config_class = AutoConfig
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = AutoModel.from_config(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.model.set_input_embeddings(value)
+
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        num_logits_to_keep: int = 0,
+        **kwargs: Unpack[KwargsForCausalLM],
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            return_dict=True,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        output = CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+        return output if return_dict else output.to_tuple()
+
+
+class AutoForSequenceClassification(PreTrainedModel):
+    config_class = AutoConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = AutoModel.from_config(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            return_dict=True,
+            **kwargs
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
+class AutoForQuestionAnswering(PreTrainedModel):
+    base_model_prefix = "transformer"
+    config_class = AutoConfig
+
+    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = AutoModel.from_config(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.transformer.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.transformer.embed_tokens = value
+
+    @validate_config_kwargs
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            return_dict=True,
+            **kwargs
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        loss = None
+        if start_positions is not None and end_positions is not None:
+            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class AutoForTokenClassification(PreTrainedModel):
+    config_class = AutoConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = AutoModel.from_config(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict=None,
+        **kwargs,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            return_dict=True,
+            **kwargs,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
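All four heads in modeling_task.py follow the same recipe: build whatever backbone AutoModel.from_config(config) resolves to, add one task-specific projection, and reuse the shared loss_function. Below is a hedged sketch of how AutoForCausalLM would be exercised on this branch with a small LlamaConfig; the sizes are placeholders, and since this is an experimental refactor the exact call may differ from released transformers versions.

import torch
from transformers import LlamaConfig

# Exposed by this branch via src/transformers/models/auto/modeling_task.py.
from transformers.models.auto.modeling_task import AutoForCausalLM

config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = AutoForCausalLM(config)  # backbone comes from AutoModel.from_config(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
outputs = model(input_ids=input_ids, labels=input_ids, return_dict=True)
print(outputs.loss, outputs.logits.shape)  # scalar loss and (1, 8, vocab_size) logits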
src/transformers/models/llama/__init__.py

@@ -50,12 +50,8 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["modeling_llama"] = [
-        "LlamaForCausalLM",
         "LlamaModel",
         "LlamaPreTrainedModel",
-        "LlamaForSequenceClassification",
-        "LlamaForQuestionAnswering",
-        "LlamaForTokenClassification",
     ]

 try:
@@ -93,10 +89,6 @@ if TYPE_CHECKING:
         pass
     else:
         from .modeling_llama import (
-            LlamaForCausalLM,
-            LlamaForQuestionAnswering,
-            LlamaForSequenceClassification,
-            LlamaForTokenClassification,
             LlamaModel,
             LlamaPreTrainedModel,
         )
(File diff suppressed because it is too large.)
@ -26,9 +26,10 @@ from contextlib import ExitStack, contextmanager
|
|||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
|
from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict, Union, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from .import_utils import (
|
from .import_utils import (
|
||||||
@ -871,6 +872,54 @@ class LossKwargs(TypedDict, total=False):
     num_items_in_batch: Optional[int]


+class KwargsForCausalLM(LossKwargs):
+    input_ids: torch.LongTensor = None
+    attention_mask: Optional[torch.Tensor] = None
+    position_ids: Optional[torch.LongTensor] = None
+    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None
+    inputs_embeds: Optional[torch.FloatTensor] = None
+    labels: Optional[torch.LongTensor] = None
+    use_cache: Optional[bool] = None
+    output_attentions: Optional[bool] = None
+    output_hidden_states: Optional[bool] = None
+    return_dict: Optional[bool] = None
+    cache_position: Optional[torch.LongTensor] = None
+    num_logits_to_keep: int = 0
+
+
+def validate_config_kwargs(func):
+    """
+    A decorator to validate and initialize kwargs based on a config object.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        self = args[0]
+        # Default values from the config
+        default_kwargs = {
+            "output_attentions": self.config.output_attentions,
+            "output_hidden_states": self.config.output_hidden_states,
+            "use_cache": self.config.use_cache,
+            "return_dict": self.config.use_return_dict,
+        }
+
+        # Merge provided kwargs with defaults
+        validated_kwargs = {**default_kwargs, **kwargs}
+
+        # # Validate kwargs against TypedDict
+        # for key in validated_kwargs:
+        #     if key not in KwargsForCausalLM.__annotations__:
+        #         raise ValueError(f"Invalid keyword argument: {key}")
+
+        if self.gradient_checkpointing and self.training and default_kwargs["use_cache"]:
+            validated_kwargs["use_cache"] = False
+
+        # Pass the validated kwargs to the function
+        return func(*args, **validated_kwargs)
+
+    return wrapper
+
+
 def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
     """Checks whether a config dict is a timm config dict."""
     return "pretrained_cfg" in config_dict
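A minimal usage sketch of the validate_config_kwargs decorator added in the hunk above. It assumes this branch, where the decorator lives in transformers/utils/generic.py; DummyConfig and DummyModel are illustrative stand-ins, not classes from the diff.

from transformers.utils.generic import validate_config_kwargs  # import path is an assumption (branch-only)


class DummyConfig:
    output_attentions = False
    output_hidden_states = False
    use_cache = True
    use_return_dict = True


class DummyModel:
    def __init__(self):
        self.config = DummyConfig()
        self.gradient_checkpointing = True
        self.training = True

    @validate_config_kwargs
    def forward(self, **kwargs):
        # kwargs now carries the config-derived defaults merged with caller overrides
        return kwargs


print(DummyModel().forward(output_attentions=True))
# {'output_attentions': True, 'output_hidden_states': False, 'use_cache': False, 'return_dict': True}

The wrapper pulls defaults from self.config, lets caller kwargs override them, and forces use_cache=False when gradient checkpointing is enabled while training.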
@ -23,6 +23,12 @@ from parameterized import parameterized

 from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.generation.configuration_utils import GenerationConfig
+from transformers.models.auto import (
+    AutoForCausalLM,
+    AutoForQuestionAnswering,
+    AutoForSequenceClassification,
+    AutoForTokenClassification,
+)
 from transformers.testing_utils import (
     cleanup,
     require_flash_attn,
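For orientation, a sketch of how the refactored tests build task heads from a config through the new auto task classes, as the hunks below do. The imports only exist on this branch, and the tiny config values simply mirror the kind of config LlamaModelTester constructs; they are not taken from the diff.

from transformers import LlamaConfig
from transformers.models.auto import AutoForCausalLM, AutoForSequenceClassification  # branch-only module

# A deliberately small config so the models build quickly in unit tests
config = LlamaConfig(
    vocab_size=99,
    hidden_size=32,
    intermediate_size=37,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
lm_model = AutoForCausalLM(config=config)           # previously LlamaForCausalLM(config=config)
classifier = AutoForSequenceClassification(config)  # previously LlamaForSequenceClassification(config)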
@ -44,14 +50,9 @@ if is_torch_available():
     import torch

     from transformers import (
-        LlamaForCausalLM,
-        LlamaForQuestionAnswering,
-        LlamaForSequenceClassification,
-        LlamaForTokenClassification,
         LlamaModel,
         LlamaTokenizer,
     )
-    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding


 class LlamaModelTester:
@ -197,7 +198,7 @@ class LlamaModelTester:
         encoder_hidden_states,
         encoder_attention_mask,
     ):
-        model = LlamaForCausalLM(config=config)
+        model = AutoForCausalLM(config=config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@ -217,7 +218,7 @@ class LlamaModelTester:
     ):
         config.is_decoder = True
         config.add_cross_attention = True
-        model = LlamaForCausalLM(config=config)
+        model = AutoForCausalLM(config=config)
         model.to(torch_device)
         model.eval()

@ -285,23 +286,23 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     all_model_classes = (
         (
             LlamaModel,
-            LlamaForCausalLM,
-            LlamaForSequenceClassification,
-            LlamaForQuestionAnswering,
-            LlamaForTokenClassification,
+            AutoForCausalLM,
+            AutoForSequenceClassification,
+            AutoForQuestionAnswering,
+            AutoForTokenClassification,
         )
         if is_torch_available()
         else ()
     )
-    all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
+    all_generative_model_classes = (AutoForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": LlamaModel,
-            "text-classification": LlamaForSequenceClassification,
-            "text-generation": LlamaForCausalLM,
-            "zero-shot": LlamaForSequenceClassification,
-            "question-answering": LlamaForQuestionAnswering,
-            "token-classification": LlamaForTokenClassification,
+            "text-classification": AutoForSequenceClassification,
+            "text-generation": AutoForCausalLM,
+            "zero-shot": AutoForSequenceClassification,
+            "question-answering": AutoForQuestionAnswering,
+            "token-classification": AutoForTokenClassification,
         }
         if is_torch_available()
         else {}
@ -315,7 +316,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     model_split_percents = [0.5, 0.7, 0.8]

     # used in `test_torch_compile_for_training`
-    _torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
+    _torch_compile_train_cls = AutoForCausalLM if is_torch_available() else None

     def setUp(self):
         self.model_tester = LlamaModelTester(self)
@ -340,7 +341,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         input_ids = input_dict["input_ids"]
         attention_mask = input_ids.ne(1).to(torch_device)
         sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = LlamaForSequenceClassification(config)
+        model = AutoForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@ -353,7 +354,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         input_ids = input_dict["input_ids"]
         attention_mask = input_ids.ne(1).to(torch_device)
         sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = LlamaForSequenceClassification(config)
+        model = AutoForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@ -368,7 +369,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         sequence_labels = ids_tensor(
             [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
         ).to(torch.float)
-        model = LlamaForSequenceClassification(config)
+        model = AutoForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@ -380,7 +381,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         input_ids = input_dict["input_ids"]
         attention_mask = input_ids.ne(1).to(torch_device)
         token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = LlamaForTokenClassification(config=config)
+        model = AutoForTokenClassification(config=config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
@ -424,71 +425,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

-    def test_model_rope_scaling(self):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        scaling_factor = 10
-        short_input_length = 10
-        long_input_length = int(config.max_position_embeddings * 1.5)
-
-        # Inputs
-        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exlusively to get the dtype and the device
-        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        position_ids_short = position_ids_short.unsqueeze(0)
-        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        position_ids_long = position_ids_long.unsqueeze(0)
-
-        # Sanity check original RoPE
-        original_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
-        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
-        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
-        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
-
-        # Sanity check linear RoPE scaling
-        # New position "x" should match original position with index "x/scaling_factor"
-        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
-        linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
-        linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
-        linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
-        for new_position in range(0, long_input_length, scaling_factor):
-            original_position = int(new_position // scaling_factor)
-            torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
-            torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
-
-        # Sanity check Dynamic NTK RoPE scaling
-        # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
-        # with scaling_factor (or that `inv_freq` decreases)
-        config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
-        ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
-        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
-        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(ntk_cos_short, original_cos_short)
-        torch.testing.assert_close(ntk_sin_short, original_sin_short)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_cos_long, original_cos_long)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_sin_long, original_sin_long)
-        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
-
-        # Sanity check Yarn RoPE scaling
-        # Scaling should be over the entire input
-        config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
-        yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
-        yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
-        yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(yarn_cos_short, original_cos_short)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(yarn_sin_short, original_sin_short)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(yarn_cos_long, original_cos_long)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(yarn_sin_long, original_sin_long)
-
     def test_model_loading_old_rope_configs(self):
         def _reinitialize_config(base_config, new_kwargs):
             # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
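For context on the rope-scaling dicts that the surviving test below keeps exercising, a small sketch; the concrete values are illustrative only.

from transformers import LlamaConfig

# "rope_type" is the current key; "type" is still accepted for backward compatibility,
# which is exactly what test_model_loading_old_rope_configs checks below.
config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 10.0})
print(config.rope_scaling)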
@ -499,17 +435,17 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi

         # from untouched config -> ✅
         base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
-        original_model = LlamaForCausalLM(base_config).to(torch_device)
+        original_model = AutoForCausalLM(base_config).to(torch_device)
         original_model(**model_inputs)

         # from a config with the expected rope configuration -> ✅
         config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
-        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model = AutoForCausalLM(config).to(torch_device)
         original_model(**model_inputs)

         # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
         config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
-        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model = AutoForCausalLM(config).to(torch_device)
         original_model(**model_inputs)

         # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
@ -518,13 +454,13 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         )
         self.assertTrue(config.rope_scaling["type"] == "linear")
         self.assertTrue(config.rope_scaling["rope_type"] == "linear")
-        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model = AutoForCausalLM(config).to(torch_device)
         original_model(**model_inputs)

         # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
         with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
             config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
-            original_model = LlamaForCausalLM(config).to(torch_device)
+            original_model = AutoForCausalLM(config).to(torch_device)
             original_model(**model_inputs)
         self.assertEqual(len(logs.output), 1)
         self.assertIn("factor field", logs.output[0])
@ -534,7 +470,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             config = _reinitialize_config(
                 base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
             )
-            original_model = LlamaForCausalLM(config).to(torch_device)
+            original_model = AutoForCausalLM(config).to(torch_device)
             original_model(**model_inputs)
         self.assertEqual(len(logs.output), 1)
         self.assertIn("Unrecognized keys", logs.output[0])
@ -557,7 +493,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
                 model = model_class(config)
                 model.save_pretrained(tmp_dir)

-                new_model = LlamaForCausalLM.from_pretrained(
+                new_model = AutoForCausalLM.from_pretrained(
                     tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
                 ).to("cuda")

@ -608,7 +544,7 @@ class LlamaIntegrationTest(unittest.TestCase):
         )

         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
             "meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
         )
         input_text = ["Tell me about the french revolution."]
@ -623,7 +559,7 @@ class LlamaIntegrationTest(unittest.TestCase):
     def test_model_7b_logits_bf16(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]

-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
             "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
         )

@ -667,7 +603,7 @@ class LlamaIntegrationTest(unittest.TestCase):
     def test_model_7b_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]

-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
             "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
         )

@ -717,7 +653,7 @@ class LlamaIntegrationTest(unittest.TestCase):
         )
         prompt = "Simply put, the theory of relativity states that "
         tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
             "meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
         )
         model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@ -754,7 +690,7 @@ class LlamaIntegrationTest(unittest.TestCase):
             "My favorite all time favorite condiment is ketchup.",
         ]
         tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
             "meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
         )
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
@ -819,7 +755,7 @@ class LlamaIntegrationTest(unittest.TestCase):
         cache_implementation = "static"
         attn_implementation = "sdpa"
         batch_size = 1
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoForCausalLM.from_pretrained(
             llama_model_ckp,
             device_map=device,
             torch_dtype=dtype,
@ -859,7 +795,7 @@ class Mask4DTestHard(unittest.TestCase):
         model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
         self.model_dtype = torch.float32
         self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
-        self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+        self.model = AutoForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

     def get_test_data(self):
         template = "my favorite {}"
@ -943,78 +943,79 @@ class ModelTesterMixin:
             encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

         for model_class in self.all_model_classes:
-            inputs_dict["output_attentions"] = True
-            inputs_dict["output_hidden_states"] = False
-            config.return_dict = True
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-            with torch.no_grad():
-                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
-            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            with self.subTest(model_class.__name__):
+                inputs_dict["output_attentions"] = True
+                inputs_dict["output_hidden_states"] = False
+                config.return_dict = True
+                model = model_class(config)
+                model.to(torch_device)
+                model.eval()
+                with torch.no_grad():
+                    outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+                attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+                self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

             # check that output_attentions also work using config
             del inputs_dict["output_attentions"]
             config.output_attentions = True
             model = model_class(config)
             model.to(torch_device)
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
             attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

             if chunk_length is not None:
                 self.assertListEqual(
                     list(attentions[0].shape[-4:]),
                     [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                 )
             else:
                 self.assertListEqual(
                     list(attentions[0].shape[-3:]),
                     [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                 )
             out_len = len(outputs)

             if self.is_encoder_decoder:
                 correct_outlen = 5

                 # loss is at first position
                 if "labels" in inputs_dict:
                     correct_outlen += 1  # loss is added to beginning
                 # Question Answering model returns start_logits and end_logits
                 if model_class.__name__ in [
                     *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
                     *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
                 ]:
                     correct_outlen += 1  # start_logits and end_logits instead of only 1 output
                 if "past_key_values" in outputs:
                     correct_outlen += 1  # past_key_values have been returned

                 self.assertEqual(out_len, correct_outlen)

                 # decoder attentions
                 decoder_attentions = outputs.decoder_attentions
                 self.assertIsInstance(decoder_attentions, (list, tuple))
                 self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                 self.assertListEqual(
                     list(decoder_attentions[0].shape[-3:]),
                     [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                 )

                 # cross attentions
                 cross_attentions = outputs.cross_attentions
                 self.assertIsInstance(cross_attentions, (list, tuple))
                 self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
                 self.assertListEqual(
                     list(cross_attentions[0].shape[-3:]),
                     [
                         self.model_tester.num_attention_heads,
                         decoder_seq_length,
                         encoder_key_length,
                     ],
                 )

             # Check attention is always last and order is fine
             inputs_dict["output_attentions"] = True
@ -1022,30 +1023,31 @@ class ModelTesterMixin:
             model = model_class(config)
             model.to(torch_device)
             model.eval()
-            with torch.no_grad():
-                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            with self.subTest(model_class.__name__):
+                with torch.no_grad():
+                    outputs = model(**self._prepare_for_class(inputs_dict, model_class))

             if hasattr(self.model_tester, "num_hidden_states_types"):
                 added_hidden_states = self.model_tester.num_hidden_states_types
             elif self.is_encoder_decoder:
                 added_hidden_states = 2
             else:
                 added_hidden_states = 1
             self.assertEqual(out_len + added_hidden_states, len(outputs))

             self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
             if chunk_length is not None:
                 self.assertListEqual(
                     list(self_attentions[0].shape[-4:]),
                     [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                 )
             else:
                 self.assertListEqual(
                     list(self_attentions[0].shape[-3:]),
                     [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                 )

     @slow
     def test_torchscript_simple(self):
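Both hunks above wrap the per-model-class assertions in unittest's subTest, so a failure for one architecture is reported on its own instead of aborting the loop. A generic illustration of the mechanism, not taken from the diff:

import unittest


class Demo(unittest.TestCase):
    def test_all_classes(self):
        for name in ("ModelA", "ModelB", "ModelC"):
            # Each iteration gets its own pass/fail entry; a failure here does not stop the loop.
            with self.subTest(name):
                self.assertTrue(name.startswith("Model"))


if __name__ == "__main__":
    unittest.main()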