Compare commits

...

67 Commits

Author SHA1 Message Date
a87cbe7f81 Woops 2023-03-07 16:43:33 +01:00
78f0063bd0 Add missing info 2023-03-04 21:54:13 +01:00
1a8c34307a Woops 2023-03-04 21:42:07 +01:00
3211e28ad4 Make style 2023-03-04 21:27:35 +01:00
073ef25c24 Add complex implementation 2023-03-04 21:16:12 +01:00
7b851d0036 Keep QKV fused 2023-03-04 20:59:31 +01:00
06caa46f7b Unfuse QKV 2023-03-04 20:43:00 +01:00
67e9476c8d Fix division 2023-03-04 20:30:12 +01:00
861f772c3b No need for this part anymore 2023-03-04 19:37:33 +01:00
e5c0bd9e38 Woops 2023-03-04 19:36:58 +01:00
e6adec6d24 Woops 2023-03-04 19:30:23 +01:00
d8a1c26f74 Maybe this will help 2023-03-04 19:26:10 +01:00
e239ca29d9 WIP 2023-03-04 11:38:31 +01:00
e1b1462334 Revert "Faster RMSNorm"
This reverts commit 55c552571334ad71cc298cee88bc7f21522c6f05.
2023-03-03 18:30:33 +01:00
55c5525713 Faster RMSNorm 2023-03-03 18:05:00 +01:00
c3bb84680f Add a sanity check 2023-03-03 17:34:17 +01:00
c2740fdf78 Woops 2023-03-03 17:30:18 +01:00
2a3ea85ddc float16 weight 2023-03-03 17:24:19 +01:00
de550aa4a7 This seems wrong 2023-03-03 17:22:12 +01:00
4ff878a905 Woops 2023-03-03 17:07:49 +01:00
e17ed99fee Missing save_vocabulary 2023-03-03 17:03:10 +01:00
39a92f0b8e More fixes 2023-03-03 17:01:01 +01:00
8eaa036417 More fixes 2023-03-03 16:57:35 +01:00
783063a9df Different implementation of rotary embeddings 2023-03-03 16:52:05 +01:00
5bbef4284a Update conversion script 2023-03-03 16:15:24 +01:00
49e5f420b5 Make style 2023-03-03 15:32:09 +01:00
deb914fc56 Add tokenizer tests 2023-03-03 15:28:12 +01:00
0c4c821e8f Woops 2023-03-03 15:05:18 +01:00
0ffcb9c265 Make style 2023-03-03 14:57:50 +01:00
c343154509 Fix tokenizer 2023-03-03 14:57:06 +01:00
dc37c4da4d Remove fast tokenizer 2023-03-03 14:31:45 +01:00
5dc2ac759c Try adding a fast tokenizer 2023-03-03 14:30:34 +01:00
db473901cf Fix test 2023-03-03 11:43:24 +01:00
62709c8c3a Fix test 2023-03-03 11:42:25 +01:00
092061868a Make apply_rotary_pos_emb contiguous 2023-03-03 11:32:44 +01:00
d0e439995f Change configuration default to tie_word_embeddings is False 2023-03-03 11:14:56 +01:00
fc0eb91991 Make style 2023-03-02 19:00:02 +01:00
c9f3ff99b7 Add License 2023-03-02 18:59:24 +01:00
84c2f3f36f Try a better tokenization system 2023-03-02 18:29:38 +01:00
1cceb9c548 remove tied word embeddings 2023-03-02 18:18:26 +01:00
db3b574a42 Woops 2023-03-02 18:05:07 +01:00
4c621407b1 Woops 2023-03-02 17:36:21 +01:00
b93226cbe1 Woops 2023-03-02 17:32:45 +01:00
f17a1b8f59 Woops 2023-03-02 17:22:09 +01:00
84ece4ae7e Woops 2023-03-02 17:14:13 +01:00
0c8318e2ec Woops 2023-03-02 17:12:15 +01:00
554b81217b Woops 2023-03-02 17:10:39 +01:00
3bb813c087 Attention mask has to be bool 2023-03-02 17:08:32 +01:00
b4f3cd004c Head size for rotary embeddings 2023-03-02 16:56:44 +01:00
bc91db32ac Woops 2023-03-02 16:52:11 +01:00
3886b86f7b Woops 2023-03-02 16:46:47 +01:00
fad287577c Fix tokenizer 2023-03-02 16:41:16 +01:00
96ef4bdac3 Woops 2023-03-02 16:28:53 +01:00
3be1e4aba9 Test something on the tokenizer 2023-03-02 16:27:44 +01:00
f1b25bb912 Remove license waiting for clear understanding of the license we need to use 2023-03-02 15:31:30 +01:00
da7f13f044 Transformers has a complex config parsing mechanism 2023-03-02 15:24:47 +01:00
54fd14eb46 Transformers has a complex config parsing mechanism 2023-03-02 15:22:58 +01:00
bf65481dbb Woops 2023-03-02 15:07:47 +01:00
e68c8c6059 Get faster init 2023-03-02 15:06:13 +01:00
b630847276 Woops 2023-03-02 15:00:41 +01:00
8c5e7c5748 Got indexing mixed between row linear and column linear 2023-03-02 14:53:45 +01:00
74c1280328 More woops 2023-03-02 14:49:46 +01:00
9a9cd3d5ee Woops 2023-03-02 14:00:04 +01:00
3b217176c5 Parse as Path 2023-03-02 13:54:43 +01:00
1a73aef67e Woops 2023-03-02 13:52:51 +01:00
19ac5d05a0 WIP 2023-03-02 13:37:37 +01:00
2ea20d0652 WIP 2023-03-02 12:01:26 +01:00
15 changed files with 1676 additions and 1 deletion

View File

@@ -0,0 +1,48 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LLaMa
## Overview
The LLaMa model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
*<INSERT PAPER ABSTRACT HERE>*
Tips:
<INSERT TIPS ABOUT MODEL HERE>
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
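Below is a minimal usage sketch. It only exercises the classes added in this PR; the checkpoint path is a hypothetical local directory produced by the conversion script included in this change set.

```python
from transformers import LLaMaForCausalLM, LLaMaTokenizer

# "path/to/converted/llama-7b" is a placeholder for a local directory created by the
# conversion script added in this PR (config.json, pytorch weights and tokenizer.model).
tokenizer = LLaMaTokenizer.from_pretrained("path/to/converted/llama-7b")
model = LLaMaForCausalLM.from_pretrained("path/to/converted/llama-7b")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```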
## LLaMaConfig
[[autodoc]] LLaMaConfig
## LLaMaTokenizer
[[autodoc]] LLaMaTokenizer
## LLaMaModel
[[autodoc]] LLaMaModel
- forward
## LLaMaForCausalLM
[[autodoc]] LLaMaForCausalLM
- forward

View File

@@ -346,6 +346,7 @@ _import_structure = {
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
"models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"],
"models.lilt": ["LILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LiltConfig"],
"models.llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LLaMaConfig"],
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
"models.longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config"],
"models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"],
@@ -664,6 +665,7 @@ else:
_import_structure["models.fnet"].append("FNetTokenizer")
_import_structure["models.gpt_sw3"].append("GPTSw3Tokenizer")
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizer")
_import_structure["models.llama"].append("LLaMaTokenizer")
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
@@ -1795,6 +1797,14 @@ else:
"LiltPreTrainedModel",
]
)
_import_structure["models.llama"].extend(
[
"LLaMaForCausalLM",
"LLaMaLayer",
"LLaMaModel",
"LLaMaPreTrainedModel",
]
)
_import_structure["models.longformer"].extend(
[
"LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -3939,6 +3949,7 @@ if TYPE_CHECKING:
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
from .models.lilt import LILT_PRETRAINED_CONFIG_ARCHIVE_MAP, LiltConfig
from .models.llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LLaMaConfig
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
from .models.longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config
from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer
@@ -4231,6 +4242,7 @@ if TYPE_CHECKING:
from .models.fnet import FNetTokenizer
from .models.gpt_sw3 import GPTSw3Tokenizer
from .models.layoutxlm import LayoutXLMTokenizer
from .models.llama import LLaMaTokenizer
from .models.m2m_100 import M2M100Tokenizer
from .models.marian import MarianTokenizer
from .models.mbart import MBart50Tokenizer, MBartTokenizer
@@ -5158,6 +5170,12 @@ if TYPE_CHECKING:
LiltModel,
LiltPreTrainedModel,
)
from .models.llama import (
LLaMaForCausalLM,
LLaMaLayer,
LLaMaModel,
LLaMaPreTrainedModel,
)
from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
LongformerForMaskedLM,

View File

@@ -1199,7 +1199,7 @@ class GenerationMixin:
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
"You have modified the pretrvained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation)"

View File

@@ -101,6 +101,7 @@ from . import (
led,
levit,
lilt,
llama,
longformer,
longt5,
luke,

View File

@@ -107,6 +107,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("led", "LEDConfig"),
("levit", "LevitConfig"),
("lilt", "LiltConfig"),
("llama", "LLaMaConfig"),
("longformer", "LongformerConfig"),
("longt5", "LongT5Config"),
("luke", "LukeConfig"),
@@ -281,6 +282,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("levit", "LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lilt", "LILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("llama", "LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longt5", "LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -457,6 +459,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("led", "LED"),
("levit", "LeViT"),
("lilt", "LiLT"),
("llama", "LLaMa"),
("longformer", "Longformer"),
("longt5", "LongT5"),
("luke", "LUKE"),

View File

@@ -106,6 +106,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("led", "LEDModel"),
("levit", "LevitModel"),
("lilt", "LiltModel"),
("llama", "LLaMaModel"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
@@ -295,6 +296,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("llama", "LLaMaForCausalLM"),
("longformer", "LongformerForMaskedLM"),
("longt5", "LongT5ForConditionalGeneration"),
("luke", "LukeForMaskedLM"),
@@ -359,6 +361,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt_neox", "GPTNeoXForCausalLM"),
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
("gptj", "GPTJForCausalLM"),
("llama", "LLaMaForCausalLM"),
("marian", "MarianForCausalLM"),
("mbart", "MBartForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),

View File

@@ -167,6 +167,7 @@ else:
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
("llama", ("LLaMaTokenizer", None)),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
(
"longt5",

View File

@@ -0,0 +1,73 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ... import is_sentencepiece_available
from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
from ...utils import OptionalDependencyNotAvailable
_import_structure = {"configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LLaMaConfig"]}
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_llama"] = ["LLaMaTokenizer"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_llama"] = [
"LLaMaForCausalLM",
"LLaMaLayer",
"LLaMaModel",
"LLaMaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LLaMaConfig
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_llama import LLaMaTokenizer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_llama import (
LLaMaForCausalLM,
LLaMaLayer,
LLaMaModel,
LLaMaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMa model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/llama": "https://huggingface.co/facebook/llama/resolve/main/config.json",
}
class LLaMaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LLaMaModel`]. It is used to instantiate a LLaMa
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a configuration similar to that of the LLaMa 7B architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMa model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LLaMaModel`].
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
Example:
```python
>>> from transformers import LLaMaConfig, LLaMaModel
>>> # Initializing a LLaMa 7B style configuration
>>> configuration = LLaMaConfig()
>>> # Initializing a model (with random weights) from the 7B style configuration
>>> model = LLaMaModel(configuration) # doctest: +SKIP
>>> # Accessing the model configuration
>>> configuration = model.config # doctest: +SKIP
```"""
model_type = "llama"
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
intermediate_size=11008,
# TODO @thomasw21: I don't think we need this at all.
max_position_embeddings=2048,
layer_norm_eps=1e-5,
use_cache=True,
# Test fails if we don't provide these values.
initializer_range=0.02,
**kwargs,
):
if "tie_word_embeddings" in kwargs:
assert kwargs["tie_word_embeddings"] is False, "LLaMa doesn't tie its input and output embeddings"
else:
# Make sure that we set it to False
kwargs["tie_word_embeddings"] = False
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.initializer_range = initializer_range

View File

@@ -0,0 +1,185 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import re
from pathlib import Path
import torch
from torch import nn
from transformers import LLaMaConfig, LLaMaForCausalLM, LLaMaTokenizer
def generate_config(params_json_path: Path, vocab_size: int) -> LLaMaConfig:
with open(params_json_path, "r") as fi:
hyperparameters = json.load(fi)
assert hyperparameters["vocab_size"] == -1, "We get vocab size information from the tokenizer"
assert vocab_size > 0
hidden_size = hyperparameters["dim"]
multiple_of = hyperparameters["multiple_of"]
intermediate_size = int(2 * 4 * hidden_size / 3)
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
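# Worked example (assuming the published 7B hyperparameters, dim=4096 and multiple_of=256):
# int(2 * 4 * 4096 / 3) = 10922, rounded up to the next multiple of 256 -> 11008,
# which matches the default `intermediate_size` in `LLaMaConfig`.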
return LLaMaConfig(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_hidden_layers=hyperparameters["n_layers"],
num_attention_heads=hyperparameters["n_heads"],
intermediate_size=intermediate_size,
torch_dtype=torch.float16,
max_position_embeddings=2048,
layer_norm_eps=hyperparameters["norm_eps"],
tie_word_embeddings=False,
use_cache=True,
)
def get_tokenizer(tokenizer_path: Path) -> LLaMaTokenizer:
return LLaMaTokenizer(str(tokenizer_path.absolute()))
original_name_to_transformers_name = {
r"^tok_embeddings.weight$": "llama.embed.weight",
r"^norm.weight$": "llama.final_layer_norm.weight",
r"^output.weight$": "lm_head.weight",
r"^layers.(\d*).attention_norm.weight$": r"llama.layers.\1.attention_norm.weight",
r"^layers.(\d*).attention.wq.weight$": r"llama.layers.\1.attention.qkv.weight",
r"^layers.(\d*).attention.wk.weight$": r"llama.layers.\1.attention.qkv.weight",
r"^layers.(\d*).attention.wv.weight$": r"llama.layers.\1.attention.qkv.weight",
r"^layers.(\d*).attention.wo.weight$": r"llama.layers.\1.attention.o.weight",
r"^layers.(\d*).ffn_norm.weight$": r"llama.layers.\1.ff_norm.weight",
r"^layers.(\d*).feed_forward.w1.weight$": r"llama.layers.\1.ff.wi_0.weight",
r"^layers.(\d*).feed_forward.w2.weight$": r"llama.layers.\1.ff.wo.weight",
r"^layers.(\d*).feed_forward.w3.weight$": r"llama.layers.\1.ff.wi_1.weight",
}
def map_original_names_to_transformers_names(original_name: str):
for pattern, repl in original_name_to_transformers_name.items():
if re.match(pattern, original_name) is None:
continue
return re.sub(pattern, repl, original_name)
raise ValueError(f"Did not expect {original_name}")
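# For illustration, the mapping turns a checkpoint name such as
# "layers.0.attention.wq.weight" into "llama.layers.0.attention.qkv.weight"
# (the three wq/wk/wv projections all map onto the single fused qkv weight).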
@torch.no_grad()
def convert_model(model_path: Path, config: LLaMaConfig) -> LLaMaForCausalLM:
# HACK @thomasw21: Bypasses `reset_parameters` which can be quite costly.
nn.Linear.reset_parameters = lambda *args: None
model = LLaMaForCausalLM(config=config).to(config.torch_dtype)
print(f"Saving weights in {config.torch_dtype}")
paths = sorted(model_path.glob("*.pth"))
tp_size = len(paths)
hf_param_set = set(model.state_dict().keys())
checkpoint_param_set = set()
for tp_rank, path in enumerate(paths):
weights = torch.load(path)
for original_name, original_param in weights.items():
if original_name.endswith(".attention.inner_attention.rope.freqs"):
print(f"We ignore {original_name} as it stores the rotary embeddings which are not in fact parameters")
continue
transformers_name = map_original_names_to_transformers_names(original_name)
transformers_param = model.get_parameter(transformers_name)
checkpoint_param_set.add(transformers_name)
assert (
original_param.dtype == transformers_param.dtype
), f"Expected dtypes to match. Got {original_param.dtype} and {transformers_param.dtype}"
if original_name.endswith("norm.weight"):
transformers_param.copy_(original_param)
continue
# weights are sharded across TP
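# These checkpoints come from a tensor-parallel run: wo/w2 are row-parallel linears
# (and tok_embeddings is split along the embedding dimension), so each rank holds a
# slice of dim 1 and fills `transformers_param[:, start:end]`; the remaining weights
# are column-parallel and fill a slice of dim 0 instead.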
if any(
original_name.endswith(suffix)
for suffix in [".feed_forward.w2.weight", ".attention.wo.weight", "tok_embeddings.weight"]
):
# Row Linear weight
input_dim = transformers_param.shape[1]
assert input_dim % tp_size == 0
step = input_dim // tp_size
start = tp_rank * step
end = (tp_rank + 1) * step
transformers_param[:, start:end].copy_(original_param)
continue
# Column linear
if any(original_name.endswith(suffix) for suffix in [".wq.weight", ".wk.weight", ".wv.weight"]):
# We fuse all the weights into a single qkv matrix.
index, suffix = [
(i, suffix)
for i, suffix in enumerate([".wq.weight", ".wk.weight", ".wv.weight"])
if original_name.endswith(suffix)
][0]
assert config.num_attention_heads % tp_size == 0
heads_per_tp_rank = config.num_attention_heads // tp_size
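# The fused HF qkv weight has shape [num_heads * 3 * head_size, hidden_size]; viewing it as
# [num_heads, 3, head_size, hidden_size] lets us address q (index 0), k (1) or v (2)
# separately for the block of heads owned by this tensor-parallel rank.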
transformer_shard = transformers_param.view(
config.num_attention_heads, 3, config.hidden_size // config.num_attention_heads, config.hidden_size
)[tp_rank * heads_per_tp_rank : (tp_rank + 1) * heads_per_tp_rank, index]
original_param = original_param.view(*transformer_shard.shape)
else:
output_dim = transformers_param.shape[0]
assert output_dim % tp_size == 0
step = output_dim // tp_size
start = tp_rank * step
end = (tp_rank + 1) * step
transformer_shard = transformers_param[start:end]
transformer_shard.copy_(original_param)
assert (
hf_param_set == checkpoint_param_set
), f"Updated params didn't match. Ref: {hf_param_set}, got: {checkpoint_param_set}"
return model
def main(args):
tokenizer = get_tokenizer(tokenizer_path=args.checkpoint_directory / "tokenizer.model")
model_path = args.checkpoint_directory / args.model_subpath
config = generate_config(model_path / "params.json", vocab_size=tokenizer.vocab_size)
model = convert_model(model_path=model_path, config=config)
config.save_pretrained(args.pytorch_dump_folder_path)
model.save_pretrained(args.pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint-directory",
type=Path,
required=True,
help="Path to the checkpoint path containing `tokenizer.json` and different model size checkpoints.",
)
parser.add_argument(
"--pytorch-dump-folder-path", type=Path, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--model-subpath",
type=Path,
required=True,
help="Subpath after going into checkpoint directory where the model checkpoint lies. Typically `7B` or `13B`",
)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,669 @@
# coding=utf-8
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMa model."""
import math
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_llama import LLaMaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LLaMaConfig"
# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LLaMa
class LLaMaLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""Construct a RMSNorm"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# LLaMaLayerNorm uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
# half-precision inputs is done in fp32
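# i.e. y = weight * x / sqrt(mean(x_i ** 2) + eps), with the mean computed in fp32 below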
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class LLaMaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LLaMaConfig
base_model_prefix = "llama"
supports_gradient_checkpointing = True
_no_split_modules = ["LLaMaLayer"]
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, LLaMaLayerNorm):
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LLaMaModel):
module.gradient_checkpointing = value
class LLaMaAttention(nn.Module):
def __init__(self, config: LLaMaConfig):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
max_positions = config.max_position_embeddings
self.register_buffer(
"causal_mask",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
self.o = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
def forward(
self,
hidden_states,
attention_mask,
complex_freq,
head_mask=None,
layer_past=None,
use_cache=False,
output_attentions=False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.qkv(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape)
# [batch, seq_len, num_attention_heads, 3, head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
# Compute token offset for rotary embeddings (when decoding)
offset = 0
if has_layer_past:
offset = layer_past[0].shape[-2]
query = apply_rotary_pos_emb(embedding=query, complex_freq=complex_freq, offset=offset)
key = apply_rotary_pos_emb(embedding=key, complex_freq=complex_freq, offset=offset)
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
attn_output = self.o(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
# -> [bs, seq_len, hidden_size]
return tensor
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.shape
key_length = key.shape[-2]
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
# TODO @thomasw21: Use `baddbmm` in order to fuse the kernels together. This comes with a loss of precision compared to original inference code
attn_scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(self.head_size)
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
# Build attention mask
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]
if attention_mask is not None:
attention_mask = causal_mask * attention_mask
else:
attention_mask = causal_mask
if attn_scores.dtype == torch.float16:
attn_scores = attn_scores.to(torch.float)
attn_scores = torch.masked_fill(attn_scores, ~attention_mask, torch.finfo(attn_scores.dtype).min)
attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000):
super().__init__()
self.dim = dim
self.base = base
self.max_seq_len_cached = max_position_embeddings
self.build_new_freq(length=max_position_embeddings, device=None)
def build_new_freq(self, length, device):
assert self.dim % 2 == 0
assert self.max_seq_len_cached <= length
self.max_seq_len_cached = length
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device)[: self.dim // 2] / self.dim)
)
self.device = self.inv_freq.device
# Build here to make `torch.jit.trace` work.
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# We don't register as a buffer as this needs to be kept in fp32 at all time.
self.complex_freq = torch.polar(torch.ones((1,), device=device, dtype=torch.float), freqs)
def forward(self, device, seq_len=None):
if seq_len > self.max_seq_len_cached or self.device != device:
self.build_new_freq(length=max(self.max_seq_len_cached, seq_len), device=device)
return self.complex_freq[:seq_len, ...]
def apply_rotary_pos_emb(embedding, complex_freq, offset: int = 0):
complex_freq = complex_freq[..., offset : embedding.shape[-2] + offset, :]
# q[...,::2] is considered the real part, q[...,1::2] is the imaginary part
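# Each adjacent pair (x_{2k}, x_{2k+1}) is viewed as one complex number and multiplied by
# e^{i * position * theta_k} taken from `complex_freq`: this complex product is exactly the
# 2D rotation used by rotary position embeddings (RoPE).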
assert complex_freq.dtype == torch.complex64
assert embedding.shape[-1] % 2 == 0
complex_embed = torch.view_as_complex(embedding.float().view(*embedding.shape[:-1], embedding.shape[-1] // 2, 2))
complex_embed_rot = complex_embed * complex_freq
embed_rot = torch.view_as_real(complex_embed_rot).view(embedding.shape)
return embed_rot.type_as(embedding)
class LLaMaFF(nn.Module):
def __init__(self, config):
super().__init__()
self.wi_0 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.wi_1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
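# Gated ("SwiGLU"-style) feed-forward: forward computes wo(silu(wi_0(x)) * wi_1(x)).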
def forward(self, hidden_states):
return self.wo(F.silu(self.wi_0(hidden_states)) * self.wi_1(hidden_states))
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer with GPTNeoX->LLaMa
class LLaMaLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention_norm = LLaMaLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = LLaMaAttention(config)
self.ff_norm = LLaMaLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.ff = LLaMaFF(config)
def forward(
self,
hidden_states,
complex_freq,
attention_mask=None,
head_mask=None,
use_cache=False,
layer_past=None,
output_attentions=False,
):
attention_layer_outputs = self.attention(
hidden_states=self.attention_norm(hidden_states),
complex_freq=complex_freq,
attention_mask=attention_mask,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
attn_output = attn_output + hidden_states
ff_output = self.ff(self.ff_norm(attn_output))
hidden_states = ff_output + attn_output
if use_cache:
outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights)
else:
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
return outputs
LLAMA_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`~LLaMaConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare LLaMa Model transformer outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel with GPTNeoX->LLaMa,GPT_NEOX->LLAMA
class LLaMaModel(LLaMaPreTrainedModel):
def __init__(self, config: LLaMaConfig):
super().__init__(config)
self.config = config
self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([LLaMaLayer(config) for _ in range(config.num_hidden_layers)])
head_size = config.hidden_size // config.num_attention_heads
self.rotary_emb = RotaryEmbedding(head_size, config.max_position_embeddings)
self.final_layer_norm = LLaMaLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed
def set_input_embeddings(self, value):
self.embed = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_key_values = tuple([None] * self.config.num_hidden_layers)
# Attention mask.
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :].to(torch.bool)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.embed(input_ids)
# Compute token offset for rotary embeddings (when decoding)
all_seq_length = seq_length
if past_key_values[0] is not None:
all_seq_length += past_key_values[0][0].shape[-2]
complex_freq = self.rotary_emb(device=inputs_embeds.device, seq_len=all_seq_length)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for layer_past
return module(*inputs, use_cache, None, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
complex_freq,
attention_mask,
head_mask[i],
)
else:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
complex_freq=complex_freq,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_layer_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
@add_start_docstrings(
"""LLaMa Model with a `language modeling` head on top for CLM fine-tuning.""", LLAMA_START_DOCSTRING
)
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM with GPTNeoX->LLaMa,GPT_NEOX->LLAMA,gpt_neox->llama
class LLaMaForCausalLM(LLaMaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.llama = LLaMaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, value):
self.lm_head = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LLaMaForCausalLM, LLaMaConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/llama")
>>> config = LLaMaConfig.from_pretrained("facebook/llama")
>>> model = LLaMaForCausalLM.from_pretrained("facebook/llama")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.llama(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithPast(
loss=lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
}
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past

View File

@@ -0,0 +1,182 @@
# coding=utf-8
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMa."""
import os
from shutil import copyfile
from typing import List, Optional, Tuple, Union
import sentencepiece as spm
from ... import AddedToken, PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
class LLaMaTokenizer(PreTrainedTokenizer):
"""
Construct a "LLaMa", Based on [SentencePiece](https://github.com/google/sentencepiece).
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
contains the vocabulary necessary to instantiate a tokenizer.
bos_token (`str`, *optional*, defaults to `"<s>"`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
Attributes:
sp_model (`SentencePieceProcessor`):
The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
"""
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
padding_side: str = "left"
def __init__(
self, vocab_file: str, bos_token: str = "<s>", eos_token: str = "</s>", unk_token: str = "<unk>", **kwargs
) -> None:
self.sp_model = spm.SentencePieceProcessor(model_file=vocab_file)
# TODO @thomasw21: Understand if I need to have <bos> and such since they are not part of the official LLaMa model
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
# TODO @thomasw21: Why does `self.sp_model.pad_id()` return `-1`?
# pad_token=self.sp_model.pad_id(),
**kwargs,
)
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor(vocab_file)
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
# <bos>/<eos>/<unk> already have ids in LLaMa tokenizer
new_tokens = [tok for tok in new_tokens if tok not in [self.bos_token, self.eos_token, self.unk_token]]
return super()._add_tokens(new_tokens=new_tokens, special_tokens=special_tokens)
@property
def bos_token_id(self) -> Optional[int]:
result = self.sp_model.bos_id()
if result >= 0:
return result
else:
return None
@property
def eos_token_id(self) -> Optional[int]:
result = self.sp_model.eos_id()
if result >= 0:
return result
else:
return None
@property
def unk_token_id(self) -> Optional[int]:
result = self.sp_model.unk_id()
if result >= 0:
return result
else:
return None
@property
def vocab_size(self):
return self.sp_model.get_piece_size()
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
return vocab
def _tokenize(self, text: str) -> List[str]:
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token: str):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.sp_model.IdToPiece(index)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A sequence has the following format:
- single sequence: `<bos> X`
- pair of sequences: `<bos> A B`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
result = [self.sp_model.bos_id()] + token_ids_0
if token_ids_1 is not None:
result += token_ids_1
return result
def convert_tokens_to_string(self, tokens: List[str]):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ""
prev_is_special = False
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special:
out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
else:
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
return out_string.strip()
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

View File

@@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch LLaMa model. """
import unittest
from transformers import LLaMaConfig, is_torch_available
from transformers.testing_utils import require_torch, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import LLaMaForCausalLM, LLaMaModel
class LLaMaModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_labels = None
if self.use_labels:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
config = self.get_config()
return config, input_ids, input_mask, token_labels
def get_config(self):
return LLaMaConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
# TODO @thomasw21: I don't know why this was set to False by default
is_decoder=True,
initializer_range=self.initializer_range,
)
def prepare_config_and_inputs_for_decoder(self):
config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
config.is_decoder = True
return config, input_ids, input_mask, token_labels
def create_and_check_model(self, config, input_ids, input_mask):
model = LLaMaModel(config=config)
model.to(torch_device)
model.eval()
_ = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(self, config, input_ids, input_mask):
config.add_cross_attention = True
model = LLaMaModel(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_labels):
model = LLaMaForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
config.is_decoder = True
model = LLaMaForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
past_key_values = outputs.past_key_values
# create hypothetical multiple next tokens and extend next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and next attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True)
output_from_no_past = output_from_no_past["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask, token_labels = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class LLaMaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (LLaMaModel, LLaMaForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (LLaMaForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": LLaMaModel, "text-generation": LLaMaForCausalLM} if is_torch_available() else {}
)
test_pruning = False
test_missing_keys = False
test_model_parallel = False
test_head_masking = False
def setUp(self):
self.model_tester = LLaMaModelTester(self)
self.config_tester = ConfigTester(self, config_class=LLaMaConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(config, input_ids, input_mask)
def test_model_as_decoder(self):
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
input_mask = None
self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
def test_decoder_model_past_large_inputs(self):
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask)
def test_model_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
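For reference, the cache-consistency check above can be condensed into a standalone sketch: logits from incremental decoding with past_key_values should match the corresponding slice of a full forward pass over the concatenated sequence. The config values are borrowed from the tester; everything else is illustrative.
# Sketch: incremental decoding with the KV cache vs. recomputing from scratch.
import torch
from transformers import LLaMaConfig, LLaMaForCausalLM

config = LLaMaConfig(
    vocab_size=99, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=37, is_decoder=True,
)
model = LLaMaForCausalLM(config).eval()
input_ids = torch.randint(0, 99, (1, 7))
next_tokens = torch.randint(0, 99, (1, 3))
with torch.no_grad():
    past = model(input_ids, use_cache=True).past_key_values
    cached_logits = model(next_tokens, past_key_values=past).logits
    full_logits = model(torch.cat([input_ids, next_tokens], dim=-1)).logits[:, -3:]
assert torch.allclose(cached_logits, full_logits, atol=1e-3)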


@@ -0,0 +1,155 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import SPIECE_UNDERLINE, LLaMaTokenizer
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
from transformers.utils import is_tf_available, is_torch_available
from ...test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available():
FRAMEWORK = "pt"
elif is_tf_available():
FRAMEWORK = "tf"
else:
FRAMEWORK = "jax"
@require_sentencepiece
@require_tokenizers
class LLaMaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LLaMaTokenizer
test_sentencepiece = True
test_rust_tokenizer = False
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = LLaMaTokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
def test_convert_token_and_id(self):
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
token = "<s>"
token_id = 1
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
self.assertEqual(vocab_keys[0], "<unk>")
self.assertEqual(vocab_keys[1], "<s>")
# TODO @thomasw21: LLaMa tokenizer doesn't have pad token
# self.assertEqual(vocab_keys[-1], "<pad>")
self.assertEqual(len(vocab_keys), 1_101)
def test_vocab_size(self):
self.assertEqual(self.get_tokenizer().vocab_size, 1_100)
def test_full_tokenizer(self):
tokenizer = LLaMaTokenizer(SAMPLE_VOCAB)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
def get_tokenizer(self, **kwargs) -> LLaMaTokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs)
def test_bos_token_in_text_id_considered_as_text(self):
tokenizer = self.get_tokenizer()
tokenized_string_bos = tokenizer(["<s>"])
self.assertListEqual(
tokenized_string_bos["input_ids"],
[
1,
],
)
def test_tokenized_text_always_starts_with_bos_token(self):
tokenizer = self.get_tokenizer()
tokenized_texts = tokenizer(
["<s>", "", "Hello my name is John.", 'If you want to use bos token, add "<s>" in your text input.']
)
for tokenized_text in tokenized_texts["input_ids"]:
self.assertGreaterEqual(len(tokenized_text), 1)
self.assertEqual(tokenized_text[0], tokenizer.bos_token_id)
self.assertNotIn(tokenizer.bos_token_id, tokenized_text[1:])
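A condensed round-trip sketch of what test_full_tokenizer checks, assuming the same SentencePiece fixture used above (the path is the usual tests/fixtures location and may differ).
# Sketch: known pieces map to stable ids and convert back unchanged; characters
# missing from the fixture vocabulary (e.g. "é") come back as "<unk>".
from transformers import LLaMaTokenizer

tok = LLaMaTokenizer("tests/fixtures/test_sentencepiece.model")  # assumed fixture path
pieces = tok.tokenize("This is a test")        # ["▁This", "▁is", "▁a", "▁t", "est"]
ids = tok.convert_tokens_to_ids(pieces)        # [285, 46, 10, 170, 382]
assert tok.convert_ids_to_tokens(ids) == pieces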