mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
Compare commits
67 Commits
v4.43.3
...
thomas/lla
Author | SHA1 | Date | |
---|---|---|---|
a87cbe7f81 | |||
78f0063bd0 | |||
1a8c34307a | |||
3211e28ad4 | |||
073ef25c24 | |||
7b851d0036 | |||
06caa46f7b | |||
67e9476c8d | |||
861f772c3b | |||
e5c0bd9e38 | |||
e6adec6d24 | |||
d8a1c26f74 | |||
e239ca29d9 | |||
e1b1462334 | |||
55c5525713 | |||
c3bb84680f | |||
c2740fdf78 | |||
2a3ea85ddc | |||
de550aa4a7 | |||
4ff878a905 | |||
e17ed99fee | |||
39a92f0b8e | |||
8eaa036417 | |||
783063a9df | |||
5bbef4284a | |||
49e5f420b5 | |||
deb914fc56 | |||
0c4c821e8f | |||
0ffcb9c265 | |||
c343154509 | |||
dc37c4da4d | |||
5dc2ac759c | |||
db473901cf | |||
62709c8c3a | |||
092061868a | |||
d0e439995f | |||
fc0eb91991 | |||
c9f3ff99b7 | |||
84c2f3f36f | |||
1cceb9c548 | |||
db3b574a42 | |||
4c621407b1 | |||
b93226cbe1 | |||
f17a1b8f59 | |||
84ece4ae7e | |||
0c8318e2ec | |||
554b81217b | |||
3bb813c087 | |||
b4f3cd004c | |||
bc91db32ac | |||
3886b86f7b | |||
fad287577c | |||
96ef4bdac3 | |||
3be1e4aba9 | |||
f1b25bb912 | |||
da7f13f044 | |||
54fd14eb46 | |||
bf65481dbb | |||
e68c8c6059 | |||
b630847276 | |||
8c5e7c5748 | |||
74c1280328 | |||
9a9cd3d5ee | |||
3b217176c5 | |||
1a73aef67e | |||
19ac5d05a0 | |||
2ea20d0652 |
48
docs/source/en/model_doc/llama.mdx
Normal file
48
docs/source/en/model_doc/llama.mdx
Normal file
@ -0,0 +1,48 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LLaMa
|
||||
|
||||
## Overview
|
||||
|
||||
The LLaMa model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
<INSERT SHORT SUMMARY HERE>
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*<INSERT PAPER ABSTRACT HERE>*
|
||||
|
||||
Tips:
|
||||
|
||||
<INSERT TIPS ABOUT MODEL HERE>
|
||||
|
||||
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
||||
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
||||
|
||||
|
||||
## LLaMaConfig
|
||||
|
||||
[[autodoc]] LLaMaConfig
|
||||
|
||||
## LLaMaTokenizer
|
||||
|
||||
[[autodoc]] LLaMaTokenizer
|
||||
|
||||
## LLaMaModel
|
||||
|
||||
[[autodoc]] LLaMaModel
|
||||
- forward
|
||||
|
||||
## LLaMaForCausalLM
|
||||
|
||||
[[autodoc]] LLaMaForCausalLM
|
||||
- forward
|
@ -346,6 +346,7 @@ _import_structure = {
|
||||
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
|
||||
"models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"],
|
||||
"models.lilt": ["LILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LiltConfig"],
|
||||
"models.llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LLaMaConfig"],
|
||||
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
|
||||
"models.longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config"],
|
||||
"models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"],
|
||||
@ -664,6 +665,7 @@ else:
|
||||
_import_structure["models.fnet"].append("FNetTokenizer")
|
||||
_import_structure["models.gpt_sw3"].append("GPTSw3Tokenizer")
|
||||
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizer")
|
||||
_import_structure["models.llama"].append("LLaMaTokenizer")
|
||||
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
|
||||
_import_structure["models.marian"].append("MarianTokenizer")
|
||||
_import_structure["models.mbart"].append("MBartTokenizer")
|
||||
@ -1795,6 +1797,14 @@ else:
|
||||
"LiltPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.llama"].extend(
|
||||
[
|
||||
"LLaMaForCausalLM",
|
||||
"LLaMaLayer",
|
||||
"LLaMaModel",
|
||||
"LLaMaPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.longformer"].extend(
|
||||
[
|
||||
"LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -3939,6 +3949,7 @@ if TYPE_CHECKING:
|
||||
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
|
||||
from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
|
||||
from .models.lilt import LILT_PRETRAINED_CONFIG_ARCHIVE_MAP, LiltConfig
|
||||
from .models.llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LLaMaConfig
|
||||
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
|
||||
from .models.longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config
|
||||
from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer
|
||||
@ -4231,6 +4242,7 @@ if TYPE_CHECKING:
|
||||
from .models.fnet import FNetTokenizer
|
||||
from .models.gpt_sw3 import GPTSw3Tokenizer
|
||||
from .models.layoutxlm import LayoutXLMTokenizer
|
||||
from .models.llama import LLaMaTokenizer
|
||||
from .models.m2m_100 import M2M100Tokenizer
|
||||
from .models.marian import MarianTokenizer
|
||||
from .models.mbart import MBart50Tokenizer, MBartTokenizer
|
||||
@ -5158,6 +5170,12 @@ if TYPE_CHECKING:
|
||||
LiltModel,
|
||||
LiltPreTrainedModel,
|
||||
)
|
||||
from .models.llama import (
|
||||
LLaMaForCausalLM,
|
||||
LLaMaLayer,
|
||||
LLaMaModel,
|
||||
LLaMaPreTrainedModel,
|
||||
)
|
||||
from .models.longformer import (
|
||||
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
LongformerForMaskedLM,
|
||||
|
@ -1199,7 +1199,7 @@ class GenerationMixin:
|
||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||
if new_generation_config != self.generation_config:
|
||||
warnings.warn(
|
||||
"You have modified the pretrained model configuration to control generation. This is a"
|
||||
"You have modified the pretrvained model configuration to control generation. This is a"
|
||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||
" Please use a generation configuration file (see"
|
||||
" https://huggingface.co/docs/transformers/main_classes/text_generation)"
|
||||
|
@ -101,6 +101,7 @@ from . import (
|
||||
led,
|
||||
levit,
|
||||
lilt,
|
||||
llama,
|
||||
longformer,
|
||||
longt5,
|
||||
luke,
|
||||
|
@ -107,6 +107,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("led", "LEDConfig"),
|
||||
("levit", "LevitConfig"),
|
||||
("lilt", "LiltConfig"),
|
||||
("llama", "LLaMaConfig"),
|
||||
("longformer", "LongformerConfig"),
|
||||
("longt5", "LongT5Config"),
|
||||
("luke", "LukeConfig"),
|
||||
@ -281,6 +282,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("levit", "LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("lilt", "LILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("llama", "LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("longt5", "LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
@ -457,6 +459,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("led", "LED"),
|
||||
("levit", "LeViT"),
|
||||
("lilt", "LiLT"),
|
||||
("llama", "LLaMa"),
|
||||
("longformer", "Longformer"),
|
||||
("longt5", "LongT5"),
|
||||
("luke", "LUKE"),
|
||||
|
@ -106,6 +106,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("led", "LEDModel"),
|
||||
("levit", "LevitModel"),
|
||||
("lilt", "LiltModel"),
|
||||
("llama", "LLaMaModel"),
|
||||
("longformer", "LongformerModel"),
|
||||
("longt5", "LongT5Model"),
|
||||
("luke", "LukeModel"),
|
||||
@ -295,6 +296,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("ibert", "IBertForMaskedLM"),
|
||||
("layoutlm", "LayoutLMForMaskedLM"),
|
||||
("led", "LEDForConditionalGeneration"),
|
||||
("llama", "LLaMaForCausalLM"),
|
||||
("longformer", "LongformerForMaskedLM"),
|
||||
("longt5", "LongT5ForConditionalGeneration"),
|
||||
("luke", "LukeForMaskedLM"),
|
||||
@ -359,6 +361,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("gpt_neox", "GPTNeoXForCausalLM"),
|
||||
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
|
||||
("gptj", "GPTJForCausalLM"),
|
||||
("llama", "LLaMaForCausalLM"),
|
||||
("marian", "MarianForCausalLM"),
|
||||
("mbart", "MBartForCausalLM"),
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
|
@ -167,6 +167,7 @@ else:
|
||||
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("llama", ("LLaMaTokenizer", None)),
|
||||
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"longt5",
|
||||
|
73
src/transformers/models/llama/__init__.py
Normal file
73
src/transformers/models/llama/__init__.py
Normal file
@ -0,0 +1,73 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ... import is_sentencepiece_available
|
||||
from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
|
||||
from ...utils import OptionalDependencyNotAvailable
|
||||
|
||||
|
||||
_import_structure = {"configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LLaMaConfig"]}
|
||||
|
||||
try:
|
||||
if not is_sentencepiece_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["tokenization_llama"] = ["LLaMaTokenizer"]
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_llama"] = [
|
||||
"LLaMaForCausalLM",
|
||||
"LLaMaLayer",
|
||||
"LLaMaModel",
|
||||
"LLaMaPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LLaMaConfig
|
||||
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .tokenization_llama import LLaMaTokenizer
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_llama import (
|
||||
LLaMaForCausalLM,
|
||||
LLaMaLayer,
|
||||
LLaMaModel,
|
||||
LLaMaPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
103
src/transformers/models/llama/configuration_llama.py
Normal file
103
src/transformers/models/llama/configuration_llama.py
Normal file
@ -0,0 +1,103 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" LLaMa model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/llama": "https://huggingface.co/facebook/llama/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class LLaMaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`LLaMaModel`]. It is used to instantiate an LLaMa
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the LLaMa architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 50432):
|
||||
Vocabulary size of the LLaMa model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`LLaMaModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 6144):
|
||||
Dimension of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 44):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 64):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import LLaMaConfig, LLaMaModel
|
||||
|
||||
>>> # Initializing a LLaMa gpt-neox-20b style configuration
|
||||
>>> configuration = LLaMaConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the gpt-neox-20b style configuration
|
||||
>>> model = LLaMaModel(configuration) # doctest: +SKIP
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config # doctest: +SKIP
|
||||
```"""
|
||||
model_type = "llama"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
intermediate_size=11008,
|
||||
# TODO @thomasw21: I don't thing we need this at all.
|
||||
max_position_embeddings=2048,
|
||||
layer_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
# Test fails is we don't provide these values.
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
if "tie_word_embeddings" in kwargs:
|
||||
assert kwargs["tie_word_embeddings"] is False, "LLaMa doesn't have tied embeddings layer"
|
||||
else:
|
||||
# Make sure that we set it at False
|
||||
kwargs["tie_word_embeddings"] = False
|
||||
super().__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.initializer_range = initializer_range
|
@ -0,0 +1,185 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import LLaMaConfig, LLaMaForCausalLM, LLaMaTokenizer
|
||||
|
||||
|
||||
def generate_config(params_json_path: Path, vocab_size: int) -> LLaMaConfig:
|
||||
with open(params_json_path, "r") as fi:
|
||||
hyperparameters = json.load(fi)
|
||||
|
||||
assert hyperparameters["vocab_size"] == -1, "We get vocab size information from the tokenizer"
|
||||
assert vocab_size > 0
|
||||
|
||||
hidden_size = hyperparameters["dim"]
|
||||
multiple_of = hyperparameters["multiple_of"]
|
||||
intermediate_size = int(2 * 4 * hidden_size / 3)
|
||||
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
|
||||
|
||||
return LLaMaConfig(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
num_hidden_layers=hyperparameters["n_layers"],
|
||||
num_attention_heads=hyperparameters["n_heads"],
|
||||
intermediate_size=intermediate_size,
|
||||
torch_dtype=torch.float16,
|
||||
max_position_embeddings=2048,
|
||||
layer_norm_eps=hyperparameters["norm_eps"],
|
||||
tie_word_embeddings=False,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
|
||||
def get_tokenzier(tokenizer_path: Path) -> LLaMaTokenizer:
|
||||
return LLaMaTokenizer(str(tokenizer_path.absolute()))
|
||||
|
||||
|
||||
original_name_to_transformers_name = {
|
||||
r"^tok_embeddings.weight$": "llama.embed.weight",
|
||||
r"^norm.weight$": "llama.final_layer_norm.weight",
|
||||
r"^output.weight$": "lm_head.weight",
|
||||
r"^layers.(\d*).attention_norm.weight$": r"llama.layers.\1.attention_norm.weight",
|
||||
r"^layers.(\d*).attention.wq.weight$": r"llama.layers.\1.attention.qkv.weight",
|
||||
r"^layers.(\d*).attention.wk.weight$": r"llama.layers.\1.attention.qkv.weight",
|
||||
r"^layers.(\d*).attention.wv.weight$": r"llama.layers.\1.attention.qkv.weight",
|
||||
r"^layers.(\d*).attention.wo.weight$": r"llama.layers.\1.attention.o.weight",
|
||||
r"^layers.(\d*).ffn_norm.weight$": r"llama.layers.\1.ff_norm.weight",
|
||||
r"^layers.(\d*).feed_forward.w1.weight$": r"llama.layers.\1.ff.wi_0.weight",
|
||||
r"^layers.(\d*).feed_forward.w2.weight$": r"llama.layers.\1.ff.wo.weight",
|
||||
r"^layers.(\d*).feed_forward.w3.weight$": r"llama.layers.\1.ff.wi_1.weight",
|
||||
}
|
||||
|
||||
|
||||
def map_original_names_to_transformers_names(original_name: str):
|
||||
for pattern, repl in original_name_to_transformers_name.items():
|
||||
if re.match(pattern, original_name) is None:
|
||||
continue
|
||||
return re.sub(pattern, repl, original_name)
|
||||
raise ValueError(f"Did not expect {original_name}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_model(model_path: Path, config: LLaMaConfig) -> LLaMaForCausalLM:
|
||||
# HACK @thomasw21: Bypasses `reset_parameters` which can be quite costly.
|
||||
nn.Linear.reset_parameters = lambda *args: None
|
||||
model = LLaMaForCausalLM(config=config).to(config.torch_dtype)
|
||||
print(f"Saving weights in {config.torch_dtype}")
|
||||
|
||||
paths = sorted(model_path.glob("*.pth"))
|
||||
tp_size = len(paths)
|
||||
hf_param_set = set(model.state_dict().keys())
|
||||
checkpoint_param_set = set()
|
||||
for tp_rank, path in enumerate(paths):
|
||||
weights = torch.load(path)
|
||||
|
||||
for original_name, original_param in weights.items():
|
||||
if original_name.endswith(".attention.inner_attention.rope.freqs"):
|
||||
print(f"We ignore {original_name} as it stores the rotary embeddings which are not in fact parameters")
|
||||
continue
|
||||
|
||||
transformers_name = map_original_names_to_transformers_names(original_name)
|
||||
transformers_param = model.get_parameter(transformers_name)
|
||||
checkpoint_param_set.add(transformers_name)
|
||||
assert (
|
||||
original_param.dtype == transformers_param.dtype
|
||||
), f"Expected dtypes to match. Got {original_param.dtype} and {transformers_param.dtype}"
|
||||
|
||||
if original_name.endswith("norm.weight"):
|
||||
transformers_param.copy_(original_param)
|
||||
continue
|
||||
|
||||
# weights are sharded across TP
|
||||
if any(
|
||||
original_name.endswith(suffix)
|
||||
for suffix in [".feed_forward.w2.weight", ".attention.wo.weight", "tok_embeddings.weight"]
|
||||
):
|
||||
# Row Linear weight
|
||||
input_dim = transformers_param.shape[1]
|
||||
assert input_dim % tp_size == 0
|
||||
step = input_dim // tp_size
|
||||
start = tp_rank * step
|
||||
end = (tp_rank + 1) * step
|
||||
transformers_param[:, start:end].copy_(original_param)
|
||||
continue
|
||||
|
||||
# Column linear
|
||||
if any(original_name.endswith(suffix) for suffix in [".wq.weight", ".wk.weight", "wv.weight"]):
|
||||
# We fuse all the weights into a single qkv matrix.
|
||||
index, suffix = [
|
||||
(i, suffix)
|
||||
for i, suffix in enumerate([".wq.weight", ".wk.weight", "wv.weight"])
|
||||
if original_name.endswith(suffix)
|
||||
][0]
|
||||
assert config.num_attention_heads % tp_size == 0
|
||||
heads_per_tp_rank = config.num_attention_heads // tp_size
|
||||
transformer_shard = transformers_param.view(
|
||||
config.num_attention_heads, 3, config.hidden_size // config.num_attention_heads, config.hidden_size
|
||||
)[tp_rank * heads_per_tp_rank : (tp_rank + 1) * heads_per_tp_rank, index]
|
||||
original_param = original_param.view(*transformer_shard.shape)
|
||||
else:
|
||||
output_dim = transformers_param.shape[0]
|
||||
assert output_dim % tp_size == 0
|
||||
step = output_dim // tp_size
|
||||
start = tp_rank * step
|
||||
end = (tp_rank + 1) * step
|
||||
transformer_shard = transformers_param[start:end]
|
||||
|
||||
transformer_shard.copy_(original_param)
|
||||
|
||||
assert (
|
||||
hf_param_set == checkpoint_param_set
|
||||
), f"Updated params didn't match. Ref: {hf_param_set}, got: {checkpoint_param_set}"
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def main(args):
|
||||
tokenizer = get_tokenzier(tokenizer_path=args.checkpoint_directory / "tokenizer.model")
|
||||
|
||||
model_path = args.checkpoint_directory / args.model_subpath
|
||||
config = generate_config(model_path / "params.json", vocab_size=tokenizer.vocab_size)
|
||||
model = convert_model(model_path=model_path, config=config)
|
||||
|
||||
config.save_pretrained(args.pytorch_dump_folder_path)
|
||||
model.save_pretrained(args.pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint-directory",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to the checkpoint path containing `tokenizer.json` and different model size checkpoints.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch-dump-folder-path", type=Path, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-subpath",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Subpath after going into checkpoint directory where the model checkpoint lies. Typically `7B` or `13B`",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
669
src/transformers/models/llama/modeling_llama.py
Normal file
669
src/transformers/models/llama/modeling_llama.py
Normal file
@ -0,0 +1,669 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch LLaMa model."""
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_llama import LLaMaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "LLaMaConfig"
|
||||
|
||||
# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LLaMa
|
||||
class LLaMaLayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""Construct a RMSNorm"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# LLaMaLayerNorm uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||
# half-precision inputs is done in fp32
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class LLaMaPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = LLaMaConfig
|
||||
base_model_prefix = "llama"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LLaMaLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, LLaMaLayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, LLaMaModel):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
class LLaMaAttention(nn.Module):
|
||||
def __init__(self, config: LLaMaConfig):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_attention_heads
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"causal_mask",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
|
||||
self.o = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
complex_freq,
|
||||
head_mask=None,
|
||||
layer_past=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
has_layer_past = layer_past is not None
|
||||
|
||||
# Compute QKV
|
||||
# Attention heads [batch, seq_len, hidden_size]
|
||||
# --> [batch, seq_len, (np * 3 * head_size)]
|
||||
qkv = self.qkv(hidden_states)
|
||||
|
||||
# [batch, seq_len, (num_heads * 3 * head_size)]
|
||||
# --> [batch, seq_len, num_heads, 3 * head_size]
|
||||
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
|
||||
qkv = qkv.view(*new_qkv_shape)
|
||||
|
||||
# [batch, seq_len, num_attention_heads, 3, head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
|
||||
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
|
||||
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
|
||||
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
|
||||
|
||||
# Compute token offset for rotary embeddings (when decoding)
|
||||
offset = 0
|
||||
if has_layer_past:
|
||||
offset = layer_past[0].shape[-2]
|
||||
query = apply_rotary_pos_emb(embedding=query, complex_freq=complex_freq, offset=offset)
|
||||
key = apply_rotary_pos_emb(embedding=key, complex_freq=complex_freq, offset=offset)
|
||||
|
||||
# Cache QKV values
|
||||
if has_layer_past:
|
||||
past_key = layer_past[0]
|
||||
past_value = layer_past[1]
|
||||
key = torch.cat((past_key, key), dim=-2)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
present = (key, value) if use_cache else None
|
||||
|
||||
# Compute attention
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
|
||||
# Reshape outputs
|
||||
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
|
||||
attn_output = self.o(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
@classmethod
|
||||
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
|
||||
"""
|
||||
Splits hidden dim into attn_head_size and num_attention_heads
|
||||
"""
|
||||
# tensor: [bs, seq_len, hidden_size]
|
||||
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
|
||||
# -> [bs, seq_len, num_attention_heads, attn_head_size]
|
||||
tensor = tensor.view(new_shape)
|
||||
# -> [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
tensor = tensor.permute(0, 2, 1, 3)
|
||||
return tensor
|
||||
|
||||
@classmethod
|
||||
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
|
||||
"""
|
||||
Merges attn_head_size dim and num_attn_heads dim into hidden dim
|
||||
"""
|
||||
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
||||
# -> [bs, seq_len, num_attention_heads, attn_head_size]
|
||||
tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
|
||||
# -> [bs, seq_len, hidden_size]
|
||||
return tensor
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
# compute causal mask from causal mask buffer
|
||||
batch_size, num_attention_heads, query_length, attn_head_size = query.shape
|
||||
key_length = key.shape[-2]
|
||||
|
||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||
|
||||
# TODO @thomasw21: Use `baddbmm` in order to fuse the kernels together. This comes with a loss of precision compared to original inference code
|
||||
attn_scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(self.head_size)
|
||||
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
|
||||
|
||||
# Build attention mask
|
||||
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]
|
||||
if attention_mask is not None:
|
||||
attention_mask = causal_mask * attention_mask
|
||||
else:
|
||||
attention_mask = causal_mask
|
||||
|
||||
if attn_scores.dtype == torch.float16:
|
||||
attn_scores = attn_scores.to(torch.float)
|
||||
attn_scores = torch.masked_fill(attn_scores, ~attention_mask, torch.finfo(attn_scores.dtype).min)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float)
|
||||
attn_weights = attn_weights.to(value.dtype)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings, base=10000):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.build_new_freq(length=max_position_embeddings, device=None)
|
||||
|
||||
def build_new_freq(self, length, device):
|
||||
assert self.dim % 2 == 0
|
||||
assert self.max_seq_len_cached <= length
|
||||
self.max_seq_len_cached = length
|
||||
self.inv_freq = 1.0 / (
|
||||
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device)[: self.dim // 2] / self.dim)
|
||||
)
|
||||
self.device = self.inv_freq.device
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# We don't register as a buffer as this needs to be kept in fp32 at all time.
|
||||
self.complex_freq = torch.polar(torch.ones((1,), device=device, dtype=torch.float), freqs)
|
||||
|
||||
def forward(self, device, seq_len=None):
|
||||
if seq_len > self.max_seq_len_cached or self.device != device:
|
||||
self.build_new_freq(length=max(self.max_seq_len_cached, seq_len), device=device)
|
||||
|
||||
return self.complex_freq[:seq_len, ...]
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(embedding, complex_freq, offset: int = 0):
|
||||
complex_freq = complex_freq[..., offset : embedding.shape[-2] + offset, :]
|
||||
# q[...,::2] is considered the real part, q[...,1::2] is the imaginary part
|
||||
assert complex_freq.dtype == torch.complex64
|
||||
assert embedding.shape[-1] % 2 == 0
|
||||
complex_embed = torch.view_as_complex(embedding.float().view(*embedding.shape[:-1], embedding.shape[-1] // 2, 2))
|
||||
complex_embed_rot = complex_embed * complex_freq
|
||||
embed_rot = torch.view_as_real(complex_embed_rot).view(embedding.shape)
|
||||
return embed_rot.type_as(embedding)
|
||||
|
||||
|
||||
class LLaMaFF(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.wi_1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.wo(F.silu(self.wi_0(hidden_states)) * self.wi_1(hidden_states))
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer with GPTNeoX->LLaMa
|
||||
class LLaMaLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.attention_norm = LLaMaLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.attention = LLaMaAttention(config)
|
||||
self.ff_norm = LLaMaLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.ff = LLaMaFF(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
complex_freq,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
layer_past=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
attention_layer_outputs = self.attention(
|
||||
hidden_states=self.attention_norm(hidden_states),
|
||||
complex_freq=complex_freq,
|
||||
attention_mask=attention_mask,
|
||||
layer_past=layer_past,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
|
||||
outputs = attention_layer_outputs[1:]
|
||||
|
||||
attn_output = attn_output + hidden_states
|
||||
ff_output = self.ff(self.ff_norm(attn_output))
|
||||
hidden_states = ff_output + attn_output
|
||||
|
||||
if use_cache:
|
||||
outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights)
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
LLAMA_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`~LLaMaConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
LLAMA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `({0})`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
LLAMA_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel with GPTNeoX->LLaMa,GPT_NEOX->LLAMA
|
||||
class LLaMaModel(LLaMaPreTrainedModel):
|
||||
def __init__(self, config: LLaMaConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = nn.ModuleList([LLaMaLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
self.rotary_emb = RotaryEmbedding(head_size, config.max_position_embeddings)
|
||||
self.final_layer_norm = LLaMaLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
output_type=BaseModelOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * self.config.num_hidden_layers)
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
assert batch_size > 0, "batch_size has to be defined and > 0"
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :].to(torch.bool)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed(input_ids)
|
||||
|
||||
# Compute token offset for rotary embeddings (when decoding)
|
||||
all_seq_length = seq_length
|
||||
if past_key_values[0] is not None:
|
||||
all_seq_length += past_key_values[0][0].shape[-2]
|
||||
complex_freq = self.rotary_emb(device=inputs_embeds.device, seq_len=all_seq_length)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for layer_past
|
||||
return module(*inputs, use_cache, None, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer),
|
||||
hidden_states,
|
||||
complex_freq,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
)
|
||||
else:
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
complex_freq=complex_freq,
|
||||
head_mask=head_mask[i],
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""LLaMa Model with a `language modeling` head on top for CLM fine-tuning.""", LLAMA_START_DOCSTRING
|
||||
)
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM with GPTNeoX->LLaMa,GPT_NEOX->LLAMA,gpt_neox->llama
|
||||
class LLaMaForCausalLM(LLaMaPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.llama = LLaMaModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.lm_head = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
|
||||
only required when the model is used as a decoder in a Sequence to Sequence model.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
||||
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LLaMaForCausalLM, LLaMaConfig
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/llama")
|
||||
>>> config = LLaMaConfig.from_pretrained("facebook/llama")
|
||||
>>> model = LLaMaForCausalLM.from_pretrained("facebook/llama")
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> prediction_logits = outputs.logits
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.llama(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
lm_loss = None
|
||||
if labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=lm_loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
|
||||
# cut decoder_input_ids if past is used
|
||||
if past_key_values and past_key_values[0] is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
}
|
||||
|
||||
def _reorder_cache(self, past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return reordered_past
|
182
src/transformers/models/llama/tokenization_llama.py
Normal file
182
src/transformers/models/llama/tokenization_llama.py
Normal file
@ -0,0 +1,182 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for LLaMa."""
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ... import AddedToken, PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
||||
|
||||
|
||||
class LLaMaTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a "LLaMa", Based on [SentencePiece](https://github.com/google/sentencepiece).
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||
this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
|
||||
contains the vocabulary necessary to instantiate a tokenizer.
|
||||
bos_token (`str`, *optional*, defaults to `"<bos>"`):
|
||||
The end of sequence token.
|
||||
eos_token (`str`, *optional*, defaults to `"<eos>"`):
|
||||
The beginning of sequence token.
|
||||
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
|
||||
Attributes:
|
||||
sp_model (`SentencePieceProcessor`):
|
||||
The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
padding_side: str = "left"
|
||||
|
||||
def __init__(
|
||||
self, vocab_file: str, bos_token: str = "<s>", eos_token: str = "</s>", unk_token: str = "<unk>", **kwargs
|
||||
) -> None:
|
||||
self.sp_model = spm.SentencePieceProcessor(model_file=vocab_file)
|
||||
|
||||
# TODO @thomasw21: Understand if I need to have <bos> and such since they are not part of the official LLaMa model
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
# TODO @thomasw21: Why the fuck is that `-1`?
|
||||
# pad_token=self.sp_model.pad_id(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
|
||||
self.sp_model = spm.SentencePieceProcessor(vocab_file)
|
||||
|
||||
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
|
||||
# <bos>/<eos>/<unk> already have ids in LLaMa tokenizer
|
||||
new_tokens = [tok for tok in new_tokens if tok not in [self.bos_token, self.eos_token, self.unk_token]]
|
||||
return super()._add_tokens(new_tokens=new_tokens, special_tokens=special_tokens)
|
||||
|
||||
def bos_token_id(self) -> Optional[int]:
|
||||
result = self.sp_model.bos_id()
|
||||
if result >= 0:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
def eos_token_id(self) -> Optional[int]:
|
||||
result = self.sp_model.eos_id()
|
||||
if result >= 0:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
def unk_token_id(self) -> Optional[int]:
|
||||
result = self.sp_model.unk_id()
|
||||
if result >= 0:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.sp_model.get_piece_size()
|
||||
|
||||
def get_vocab(self):
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
||||
return self.sp_model.encode(text, out_type=str)
|
||||
|
||||
def _convert_token_to_id(self, token: str):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.sp_model.piece_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.sp_model.IdToPiece(index)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A sequence has the following format:
|
||||
|
||||
- single sequence: `<bos> X`
|
||||
- pair of sequences: `<bos> A B`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
result = [self.sp_model.bos_id()] + token_ids_0
|
||||
if token_ids_1 is not None:
|
||||
result += token_ids_1
|
||||
return result
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
prev_is_special = False
|
||||
for token in tokens:
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
if not prev_is_special:
|
||||
out_string += " "
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
prev_is_special = True
|
||||
current_sub_tokens = []
|
||||
else:
|
||||
current_sub_tokens.append(token)
|
||||
prev_is_special = False
|
||||
out_string += self.sp_model.decode(current_sub_tokens)
|
||||
return out_string.strip()
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
elif not os.path.isfile(self.vocab_file):
|
||||
with open(out_vocab_file, "wb") as fi:
|
||||
content_spiece_model = self.sp_model.serialized_model_proto()
|
||||
fi.write(content_spiece_model)
|
||||
|
||||
return (out_vocab_file,)
|
0
tests/models/llama/__init__.py
Normal file
0
tests/models/llama/__init__.py
Normal file
234
tests/models/llama/test_modeling_llama.py
Normal file
234
tests/models/llama/test_modeling_llama.py
Normal file
@ -0,0 +1,234 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch LLaMa model. """
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import LLaMaConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import LLaMaForCausalLM, LLaMaModel
|
||||
|
||||
|
||||
class LLaMaModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
token_labels = None
|
||||
if self.use_labels:
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask, token_labels
|
||||
|
||||
def get_config(self):
|
||||
return LLaMaConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
# TODO @thomasw21: I don't why this was set to False by default
|
||||
is_decoder=True,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
|
||||
|
||||
config.is_decoder = True
|
||||
|
||||
return config, input_ids, input_mask, token_labels
|
||||
|
||||
def create_and_check_model(self, config, input_ids, input_mask):
|
||||
model = LLaMaModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
_ = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_model_as_decoder(self, config, input_ids, input_mask):
|
||||
config.add_cross_attention = True
|
||||
model = LLaMaModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_labels):
|
||||
model = LLaMaForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
|
||||
config.is_decoder = True
|
||||
model = LLaMaForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True)
|
||||
output_from_no_past = output_from_no_past["hidden_states"][0]
|
||||
output_from_past = model(
|
||||
next_tokens,
|
||||
attention_mask=next_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, input_mask, token_labels = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class LLaMaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (LLaMaModel, LLaMaForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (LLaMaForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": LLaMaModel, "text-generation": LLaMaForCausalLM} if is_torch_available() else {}
|
||||
)
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
test_model_parallel = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LLaMaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=LLaMaConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(config, input_ids, input_mask)
|
||||
|
||||
def test_model_as_decoder(self):
|
||||
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
|
||||
|
||||
def test_model_as_decoder_with_default_input_mask(self):
|
||||
# This regression test was failing with PyTorch < 1.3
|
||||
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
|
||||
input_mask = None
|
||||
|
||||
self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
|
||||
|
||||
def test_decoder_model_past_large_inputs(self):
|
||||
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask)
|
||||
|
||||
def test_model_for_causal_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Feed forward chunking is not implemented")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
155
tests/models/llama/test_tokenization_llama.py
Normal file
155
tests/models/llama/test_tokenization_llama.py
Normal file
@ -0,0 +1,155 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Google LLaMa Authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, LLaMaTokenizer
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
|
||||
from transformers.utils import is_tf_available, is_torch_available
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
if is_torch_available():
|
||||
FRAMEWORK = "pt"
|
||||
elif is_tf_available():
|
||||
FRAMEWORK = "tf"
|
||||
else:
|
||||
FRAMEWORK = "jax"
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class LLaMaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = LLaMaTokenizer
|
||||
test_sentencepiece = True
|
||||
test_rust_tokenizer = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# We have a SentencePiece fixture for testing
|
||||
tokenizer = LLaMaTokenizer(SAMPLE_VOCAB)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def test_convert_token_and_id(self):
|
||||
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
|
||||
token = "<s>"
|
||||
token_id = 1
|
||||
|
||||
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
|
||||
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
|
||||
|
||||
def test_get_vocab(self):
|
||||
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
|
||||
|
||||
self.assertEqual(vocab_keys[0], "<unk>")
|
||||
self.assertEqual(vocab_keys[1], "<s>")
|
||||
# TODO @thomasw21: LLaMa tokenizer doesn't have pad token
|
||||
# self.assertEqual(vocab_keys[-1], "<pad>")
|
||||
self.assertEqual(len(vocab_keys), 1_101)
|
||||
|
||||
def test_vocab_size(self):
|
||||
self.assertEqual(self.get_tokenizer().vocab_size, 1_100)
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = LLaMaTokenizer(SAMPLE_VOCAB)
|
||||
|
||||
tokens = tokenizer.tokenize("This is a test")
|
||||
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
|
||||
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
|
||||
|
||||
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(
|
||||
tokens,
|
||||
[
|
||||
SPIECE_UNDERLINE + "I",
|
||||
SPIECE_UNDERLINE + "was",
|
||||
SPIECE_UNDERLINE + "b",
|
||||
"or",
|
||||
"n",
|
||||
SPIECE_UNDERLINE + "in",
|
||||
SPIECE_UNDERLINE + "",
|
||||
"9",
|
||||
"2",
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
",",
|
||||
SPIECE_UNDERLINE + "and",
|
||||
SPIECE_UNDERLINE + "this",
|
||||
SPIECE_UNDERLINE + "is",
|
||||
SPIECE_UNDERLINE + "f",
|
||||
"al",
|
||||
"s",
|
||||
"é",
|
||||
".",
|
||||
],
|
||||
)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])
|
||||
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(
|
||||
back_tokens,
|
||||
[
|
||||
SPIECE_UNDERLINE + "I",
|
||||
SPIECE_UNDERLINE + "was",
|
||||
SPIECE_UNDERLINE + "b",
|
||||
"or",
|
||||
"n",
|
||||
SPIECE_UNDERLINE + "in",
|
||||
SPIECE_UNDERLINE + "",
|
||||
"<unk>",
|
||||
"2",
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
",",
|
||||
SPIECE_UNDERLINE + "and",
|
||||
SPIECE_UNDERLINE + "this",
|
||||
SPIECE_UNDERLINE + "is",
|
||||
SPIECE_UNDERLINE + "f",
|
||||
"al",
|
||||
"s",
|
||||
"<unk>",
|
||||
".",
|
||||
],
|
||||
)
|
||||
|
||||
def get_tokenizer(self, **kwargs) -> LLaMaTokenizer:
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs)
|
||||
|
||||
def test_bos_token_in_text_id_considered_as_text(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
tokenized_string_bos = tokenizer(["<s>"])
|
||||
self.assertListEqual(
|
||||
tokenized_string_bos["input_ids"],
|
||||
[
|
||||
1,
|
||||
],
|
||||
)
|
||||
|
||||
def test_tokenized_text_always_starts_with_bos_token(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
tokenized_texts = tokenizer(
|
||||
["<s>", "", "Hello my name is John.", 'If you want to use bos token, add "<s>" in your text input.']
|
||||
)
|
||||
for tokenized_text in tokenized_texts["input_ids"]:
|
||||
self.assertGreaterEqual(len(tokenized_text), 1)
|
||||
self.assertEqual(tokenized_text[0], tokenizer.bos_token_id())
|
||||
self.assertNotIn(tokenizer.bos_token_id(), tokenized_text[1:])
|
Reference in New Issue
Block a user