Compare commits

...

9 Commits

Author SHA1 Message Date
38d94bffa6 Patch release 2024-07-24 17:42:52 +02:00
b4a0442dbd Fix float8_e4m3fn in modeling_utils (#32193)
* Fix float8_e4m3fn in modeling_utils

* style

* fix

* comment
2024-07-24 17:42:35 +02:00
4672b4d79b Fix resize embedding with Deepspeed (#32192)
fix resize when deepspeed
2024-07-24 17:42:28 +02:00
a2b6a001c0 let's not warn when someone is running a forward (#32176)
* let's not warn when someone is running a foward without cache + self.training

* more models

* fixup
2024-07-24 17:42:19 +02:00
64a90d72a8 RoPE: relaxed rope validation (#32182)
* relaxed rope check

* lets also accept rope_type=None, defaulting to the original implementation

* type and rope_type can coexist
2024-07-24 17:42:14 +02:00
782bfffb2e Patch release 2024-07-23 17:46:57 +02:00
cf0534913f fix 2024-07-23 17:46:42 +02:00
7fa7508dad Release: v4.43.0 2024-07-23 16:58:49 +02:00
26b179c90d Llama 3.1 conversion 2024-07-23 16:58:49 +02:00
74 changed files with 419 additions and 164 deletions

View File

@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)

View File

@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)

View File

@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = logging.getLogger(__name__)

View File

@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version(
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = logging.getLogger(__name__)

View File

@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
logger = logging.getLogger(__name__)

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
task_to_keys = {
"cola": ("sentence", None),

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.43.0.dev0")
check_min_version("4.43.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -430,7 +430,7 @@ install_requires = [
setup(
name="transformers",
version="4.43.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.43.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.43.0.dev0"
__version__ = "4.43.2"
from typing import TYPE_CHECKING

View File

@ -129,6 +129,7 @@ def _compute_dynamic_ntk_parameters(
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
@ -249,6 +250,7 @@ def _compute_longrope_parameters(
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
# No need to keep BC with longrope, unreleased when this new pattern was created.
if len(rope_kwargs) > 0:
raise ValueError(
@ -293,6 +295,50 @@ def _compute_longrope_parameters(
return inv_freq, attention_factor
def _compute_llama3_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies for llama 3.1.
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# Gets the default RoPE parameters
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
factor = config.rope_scaling["factor"] # `8` in the original implementation
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in inv_freq:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
return inv_freq, attention_factor
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
@ -302,11 +348,17 @@ ROPE_INIT_FUNCTIONS = {
"dynamic": _compute_dynamic_ntk_parameters,
"yarn": _compute_yarn_parameters,
"longrope": _compute_longrope_parameters,
"llama3": _compute_llama3_parameters,
}
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's gracefully handle it
if "rope_type" not in received_keys and "type" in received_keys:
received_keys -= {"type"}
received_keys.add("rope_type")
missing_keys = required_keys - received_keys
if missing_keys:
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
@ -314,14 +366,14 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
if optional_keys is not None:
unused_keys = received_keys - required_keys - optional_keys
else:
unused_keys = received_keys - received_keys
unused_keys = received_keys - required_keys
if unused_keys:
raise KeyError(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
def _validate_default_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
@ -329,19 +381,33 @@ def _validate_default_rope_parameters(config: PretrainedConfig):
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"original_max_position_embeddings"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_yarn_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
received_keys = set(rope_scaling.keys())
@ -349,22 +415,22 @@ def _validate_yarn_parameters(config: PretrainedConfig):
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
beta_fast = rope_scaling.get("beta_fast")
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
beta_slow = rope_scaling.get("beta_slow")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
if (beta_fast or 32) < (beta_slow or 1):
raise ValueError(
logger.warning(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)
@ -372,9 +438,10 @@ def _validate_yarn_parameters(config: PretrainedConfig):
def _validate_longrope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "short_factor", "long_factor"}
optional_keys = {"attention_factor", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
@ -383,15 +450,15 @@ def _validate_longrope_parameters(config: PretrainedConfig):
short_factor = rope_scaling.get("short_factor")
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
raise ValueError(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
if not len(short_factor) == dim // 2:
raise ValueError(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
long_factor = rope_scaling.get("long_factor")
if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
raise ValueError(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
if not len(long_factor) == dim // 2:
raise ValueError(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
@ -406,24 +473,61 @@ def _validate_longrope_parameters(config: PretrainedConfig):
else:
factor = rope_scaling.get("factor")
if factor is None:
raise ValueError("Missing required keys in `rope_scaling`: 'factor'")
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
elif not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
raise ValueError(
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
def _validate_llama3_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
if low_freq_factor is None or not isinstance(low_freq_factor, float):
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float):
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor < low_freq_factor:
logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
)
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
logger.warning(
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
f"{original_max_position_embeddings}"
)
if original_max_position_embeddings >= config.max_position_embeddings:
logger.warning(
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
)
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
ROPE_VALIDATION_FUNCTIONS = {
"default": _validate_default_rope_parameters,
"linear": _validate_linear_scaling_rope_parameters,
"dynamic": _validate_linear_scaling_rope_parameters, # `dynamic` has the same validation pattern as `linear`
"dynamic": _validate_dynamic_scaling_rope_parameters,
"yarn": _validate_yarn_parameters,
"longrope": _validate_longrope_parameters,
"llama3": _validate_llama3_parameters,
}
@ -435,17 +539,12 @@ def rope_config_validation(config: PretrainedConfig):
if rope_scaling is None:
return
possible_rope_types = set(ROPE_INIT_FUNCTIONS.keys())
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
if rope_type is None:
raise ValueError(
f"rope_scaling must contain a non-None 'rope_type' field. Possible options are {possible_rope_types}"
)
# BC: "rope_type" was originally "type"
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
if validation_fn is not None:
validation_fn(config)
else:
raise ValueError(
logger.warning(
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)

View File

@ -855,6 +855,8 @@ def _load_state_dict_into_meta_model(
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for param_name, param in state_dict.items():
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
@ -866,9 +868,10 @@ def _load_state_dict_into_meta_model(
module_name = param_name
set_module_kwargs = {}
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn:
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and any(
@ -2131,13 +2134,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Replace weights in old_embeddings and return to maintain the same embedding type.
# This ensures correct functionality when a Custom Embedding class is passed as input.
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
# If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
# will be set to `None` in the resized embeddings.
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
# If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
# will be set to `None` in the resized embeddings.
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None
else:
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None
return old_embeddings

View File

@ -769,7 +769,9 @@ class CohereModel(CoherePreTrainedModel):
past_seen_tokens = 0
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

View File

@ -1004,7 +1004,9 @@ class DbrxModel(DbrxPreTrainedModel):
inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -474,7 +474,9 @@ class GemmaModel(LlamaModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

View File

@ -769,7 +769,9 @@ class GemmaModel(GemmaPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
@ -794,7 +796,9 @@ class GemmaModel(GemmaPreTrainedModel):
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -978,7 +978,9 @@ class JetMoeModel(JetMoePreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

View File

@ -73,25 +73,28 @@ class LlamaConfig(PretrainedConfig):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
`max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
to determine which scaling to apply.
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
with 'default' being the original RoPE implementation.
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
`max_position_embeddings`.
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
@ -104,12 +107,16 @@ class LlamaConfig(PretrainedConfig):
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
@ -182,6 +189,9 @@ class LlamaConfig(PretrainedConfig):
self.mlp_bias = mlp_bias
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(

View File

@ -17,10 +17,11 @@ import json
import os
import shutil
import warnings
from typing import List
import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import TikTokenConverter
@ -85,8 +86,12 @@ NUM_SHARDS = {
"65B": 8,
"70B": 8,
"70Bf": 8,
"405B": 8,
"405B-MP16": 16,
}
CONTEXT_LENGTH_FOR_VERSION = {"3.1": 131072, "3": 8192, "2": 4096, "1": 2048}
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
@ -107,9 +112,10 @@ def write_model(
input_base_path,
model_size=None,
safe_serialization=True,
llama_version=1,
llama_version="1",
vocab_size=None,
num_shards=None,
instruct=False,
):
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
@ -125,18 +131,11 @@ def write_model(
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
if base > 10000.0 and llama_version != 3:
if base > 10000.0 and float(llama_version) < 3:
max_position_embeddings = 16384
else:
# Depending on the Llama version, the default max_position_embeddings has different values.
if llama_version == 1:
max_position_embeddings = 2048
elif llama_version == 2:
max_position_embeddings = 4096
elif llama_version == 3:
max_position_embeddings = 8192
max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version]
vocab_size = vocab_size if vocab_size is not None else 32000
if params.get("n_kv_heads", None) is not None:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_key_value_heads_per_shard = num_key_value_heads // num_shards
@ -144,8 +143,7 @@ def write_model(
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_key_value_heads_per_shard = n_heads_per_shard
key_value_dim = dims_per_head * num_key_value_heads
print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim)
key_value_dim = dim
# permute for sliced rotary
def permute(w, n_heads, dim1=dim, dim2=dim):
@ -159,11 +157,9 @@ def write_model(
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
else:
# Sharded
loaded = [
torch.load(os.path.join(input_base_path, file), map_location="cpu")
for file in sorted(os.listdir(input_base_path))
if file.endswith(".pth")
]
checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
print("Loading in order:", checkpoint_list)
loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list]
param_count = 0
index_dict = {"weight_map": {}}
for layer_i in range(n_layers):
@ -263,7 +259,7 @@ def write_model(
"lm_head.weight": loaded["output.weight"],
}
else:
concat_dim = 0 if llama_version == 3 else 1
concat_dim = 0 if llama_version in ["3", "3.1"] else 1
state_dict = {
"model.norm.weight": loaded[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(
@ -282,6 +278,18 @@ def write_model(
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
if llama_version in ["3", "3.1"]:
bos_token_id = 128000
if instruct:
eos_token_id = [128001, 128008, 128009]
else:
eos_token_id = 128001
else:
bos_token_id = 1
eos_token_id = 2
config = LlamaConfig(
hidden_size=dim,
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
@ -292,11 +300,21 @@ def write_model(
vocab_size=vocab_size,
rope_theta=base,
max_position_embeddings=max_position_embeddings,
bos_token_id=128000 if llama_version == 3 else 1,
eos_token_id=128001 if llama_version == 3 else 2,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
)
config.save_pretrained(tmp_model_path)
if instruct:
generation_config = GenerationConfig(
do_sample=True,
temperature=0.6,
top_p=0.9,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
)
generation_config.save_pretrained(tmp_model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
@ -313,7 +331,7 @@ def write_model(
class Llama3Converter(TikTokenConverter):
def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs):
def __init__(self, vocab_file, special_tokens=None, instruct=False, model_max_length=None, **kwargs):
super().__init__(vocab_file, **kwargs)
tokenizer = self.converted()
chat_template = (
@ -327,34 +345,24 @@ class Llama3Converter(TikTokenConverter):
"{% endfor %}"
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
)
num_reserved_special_tokens = 256
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]
tokenizer.add_special_tokens(special_tokens)
self.tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
bos_token="<|begin_of_text|>",
eos_token="<|end_of_text|>",
chat_template=chat_template,
eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
chat_template=chat_template if instruct else None,
model_input_names=["input_ids", "attention_mask"],
model_max_length=model_max_length,
)
def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False):
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
if llama_version == 3:
tokenizer = Llama3Converter(input_tokenizer_path).tokenizer
if llama_version in ["3", "3.1"]:
tokenizer = Llama3Converter(
input_tokenizer_path, special_tokens, instruct, model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version]
).tokenizer
else:
tokenizer = tokenizer_class(input_tokenizer_path)
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
@ -362,6 +370,37 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
return tokenizer
DEFAULT_LLAMA_SPECIAL_TOKENS = {
"3": [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
]
+ [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)],
"3.1": [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|finetune_right_pad_id|>",
"<|reserved_special_token_2|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|eom_id|>", # end of message
"<|eot_id|>", # end of turn
"<|python_tag|>",
]
+ [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
@ -383,9 +422,9 @@ def main():
# Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
parser.add_argument(
"--llama_version",
choices=[1, 2, 3],
default=1,
type=int,
choices=["1", "2", "3", "3.1"],
default="1",
type=str,
help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
)
parser.add_argument(
@ -394,11 +433,34 @@ def main():
type=int,
help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
)
parser.add_argument(
"--special_tokens",
default=None,
type=List[str],
help="The list of special tokens that should be added to the model.",
)
parser.add_argument(
"--instruct",
default=False,
type=bool,
help="Whether the model is an instruct model or not. Will affect special tokens for llama 3.1.",
)
args = parser.parse_args()
if args.model_size is None and args.num_shards is None:
raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
if args.special_tokens is None:
args.special_tokens = DEFAULT_LLAMA_SPECIAL_TOKENS[str(args.llama_version)]
spm_path = os.path.join(args.input_dir, "tokenizer.model")
vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version))
vocab_size = len(
write_tokenizer(
args.output_dir,
spm_path,
llama_version=args.llama_version,
special_tokens=args.special_tokens,
instruct=args.instruct,
)
)
if args.model_size != "tokenizer_only":
write_model(
model_path=args.output_dir,
@ -408,6 +470,7 @@ def main():
llama_version=args.llama_version,
vocab_size=vocab_size,
num_shards=args.num_shards,
instruct=args.instruct,
)

View File

@ -107,7 +107,7 @@ class LlamaRotaryEmbedding(nn.Module):
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling["type"])
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
@ -893,7 +893,9 @@ class LlamaModel(LlamaPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -757,7 +757,7 @@ class MistralModel(MistralPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
return_legacy_cache = True
logger.warning_once(

View File

@ -959,7 +959,7 @@ class MixtralModel(MixtralPreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -810,7 +810,9 @@ class OlmoModel(OlmoPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -626,7 +626,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -908,7 +908,7 @@ class PhiModel(PhiPreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -949,7 +949,7 @@ class Phi3Model(Phi3PreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -807,7 +807,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -969,7 +969,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -901,7 +901,7 @@ class StableLmModel(StableLmPreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -783,7 +783,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@ -526,6 +526,60 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
torch.testing.assert_close(old_cos_long, new_cos_long)
torch.testing.assert_close(old_sin_long, new_sin_long)
def test_model_loading_old_rope_configs(self):
def _reinitialize_config(base_config, new_kwargs):
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
# steps.
base_config_dict = base_config.to_dict()
new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
return new_config
# from untouched config -> ✅
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
original_model = LlamaForCausalLM(base_config).to(torch_device)
original_model(**model_inputs)
# from a config with the expected rope configuration -> ✅
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
config = _reinitialize_config(
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
)
self.assertTrue(config.rope_scaling["type"] == "linear")
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("factor field", logs.output[0])
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
)
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("Unrecognized keys", logs.output[0])
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
with self.assertRaises(KeyError):
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes