mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
Compare commits
2 Commits
use_new_ti
...
v4.43.0
Author | SHA1 | Date | |
---|---|---|---|
7fa7508dad | |||
26b179c90d |
@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -430,7 +430,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.43.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.43.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.43.0.dev0"
|
||||
__version__ = "4.43.0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
@ -129,6 +129,7 @@ def _compute_dynamic_ntk_parameters(
|
||||
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
||||
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
||||
"""
|
||||
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
||||
if config is not None and len(rope_kwargs) > 0:
|
||||
raise ValueError(
|
||||
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
||||
@ -249,6 +250,7 @@ def _compute_longrope_parameters(
|
||||
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
||||
post-processing scaling factor applied to the computed cos/sin.
|
||||
"""
|
||||
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
||||
# No need to keep BC with longrope, unreleased when this new pattern was created.
|
||||
if len(rope_kwargs) > 0:
|
||||
raise ValueError(
|
||||
@ -293,6 +295,50 @@ def _compute_longrope_parameters(
|
||||
return inv_freq, attention_factor
|
||||
|
||||
|
||||
def _compute_llama3_parameters(
|
||||
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
||||
) -> Tuple["torch.Tensor", float]:
|
||||
"""
|
||||
Computes the inverse frequencies for llama 3.1.
|
||||
|
||||
Args:
|
||||
config ([`~transformers.PretrainedConfig`]):
|
||||
The model configuration.
|
||||
device (`torch.device`):
|
||||
The device to use for initialization of the inverse frequencies.
|
||||
seq_len (`int`, *optional*):
|
||||
The current sequence length. Unused for this type of RoPE.
|
||||
rope_kwargs (`Dict`, *optional*):
|
||||
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
||||
Returns:
|
||||
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
||||
post-processing scaling factor applied to the computed cos/sin.
|
||||
"""
|
||||
# Gets the default RoPE parameters
|
||||
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
||||
|
||||
factor = config.rope_scaling["factor"] # `8` in the original implementation
|
||||
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
|
||||
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
|
||||
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
new_freqs = []
|
||||
for freq in inv_freq:
|
||||
wavelen = 2 * math.pi / freq
|
||||
if wavelen < high_freq_wavelen:
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / factor)
|
||||
else:
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
|
||||
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
|
||||
return inv_freq, attention_factor
|
||||
|
||||
|
||||
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
||||
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
|
||||
# parameterizations, as long as the callable has the same signature.
|
||||
@ -302,6 +348,7 @@ ROPE_INIT_FUNCTIONS = {
|
||||
"dynamic": _compute_dynamic_ntk_parameters,
|
||||
"yarn": _compute_yarn_parameters,
|
||||
"longrope": _compute_longrope_parameters,
|
||||
"llama3": _compute_llama3_parameters,
|
||||
}
|
||||
|
||||
|
||||
@ -339,6 +386,20 @@ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
|
||||
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
required_keys = {"rope_type", "factor"}
|
||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||
optional_keys = {"original_max_position_embeddings"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
|
||||
def _validate_yarn_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
@ -374,7 +435,8 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
required_keys = {"rope_type", "short_factor", "long_factor"}
|
||||
optional_keys = {"attention_factor", "factor"}
|
||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
||||
|
||||
@ -417,13 +479,50 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
||||
)
|
||||
|
||||
|
||||
def _validate_llama3_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
low_freq_factor = rope_scaling["low_freq_factor"]
|
||||
high_freq_factor = rope_scaling["high_freq_factor"]
|
||||
if low_freq_factor is None or not isinstance(low_freq_factor, float):
|
||||
raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
||||
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
||||
raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
||||
if high_freq_factor < low_freq_factor:
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
||||
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
||||
)
|
||||
|
||||
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
|
||||
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
|
||||
f"{original_max_position_embeddings}"
|
||||
)
|
||||
if original_max_position_embeddings >= config.max_position_embeddings:
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
|
||||
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
|
||||
)
|
||||
|
||||
|
||||
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
|
||||
ROPE_VALIDATION_FUNCTIONS = {
|
||||
"default": _validate_default_rope_parameters,
|
||||
"linear": _validate_linear_scaling_rope_parameters,
|
||||
"dynamic": _validate_linear_scaling_rope_parameters, # `dynamic` has the same validation pattern as `linear`
|
||||
"dynamic": _validate_dynamic_scaling_rope_parameters,
|
||||
"yarn": _validate_yarn_parameters,
|
||||
"longrope": _validate_longrope_parameters,
|
||||
"llama3": _validate_llama3_parameters,
|
||||
}
|
||||
|
||||
|
||||
|
@ -73,25 +73,28 @@ class LlamaConfig(PretrainedConfig):
|
||||
End of stream token id.
|
||||
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
|
||||
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
||||
issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
||||
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
||||
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
|
||||
`max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
|
||||
to determine which scaling to apply.
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
|
||||
with 'default' being the original RoPE implementation.
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
`max_position_embeddings`.
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
@ -104,12 +107,16 @@ class LlamaConfig(PretrainedConfig):
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
|
@ -17,10 +17,11 @@ import json
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
|
||||
from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
|
||||
from transformers.convert_slow_tokenizer import TikTokenConverter
|
||||
|
||||
|
||||
@ -85,8 +86,12 @@ NUM_SHARDS = {
|
||||
"65B": 8,
|
||||
"70B": 8,
|
||||
"70Bf": 8,
|
||||
"405B": 8,
|
||||
"405B-MP16": 16,
|
||||
}
|
||||
|
||||
CONTEXT_LENGTH_FOR_VERSION = {"3.1": 131072, "3": 8192, "2": 4096, "1": 2048}
|
||||
|
||||
|
||||
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
|
||||
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
|
||||
@ -107,9 +112,10 @@ def write_model(
|
||||
input_base_path,
|
||||
model_size=None,
|
||||
safe_serialization=True,
|
||||
llama_version=1,
|
||||
llama_version="1",
|
||||
vocab_size=None,
|
||||
num_shards=None,
|
||||
instruct=False,
|
||||
):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
tmp_model_path = os.path.join(model_path, "tmp")
|
||||
@ -125,18 +131,11 @@ def write_model(
|
||||
dims_per_head = dim // n_heads
|
||||
base = params.get("rope_theta", 10000.0)
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
||||
if base > 10000.0 and llama_version != 3:
|
||||
if base > 10000.0 and float(llama_version) < 3:
|
||||
max_position_embeddings = 16384
|
||||
else:
|
||||
# Depending on the Llama version, the default max_position_embeddings has different values.
|
||||
if llama_version == 1:
|
||||
max_position_embeddings = 2048
|
||||
elif llama_version == 2:
|
||||
max_position_embeddings = 4096
|
||||
elif llama_version == 3:
|
||||
max_position_embeddings = 8192
|
||||
max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version]
|
||||
|
||||
vocab_size = vocab_size if vocab_size is not None else 32000
|
||||
if params.get("n_kv_heads", None) is not None:
|
||||
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
|
||||
num_key_value_heads_per_shard = num_key_value_heads // num_shards
|
||||
@ -144,8 +143,7 @@ def write_model(
|
||||
else: # compatibility with other checkpoints
|
||||
num_key_value_heads = n_heads
|
||||
num_key_value_heads_per_shard = n_heads_per_shard
|
||||
key_value_dim = dims_per_head * num_key_value_heads
|
||||
print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim)
|
||||
key_value_dim = dim
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w, n_heads, dim1=dim, dim2=dim):
|
||||
@ -159,11 +157,9 @@ def write_model(
|
||||
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
|
||||
else:
|
||||
# Sharded
|
||||
loaded = [
|
||||
torch.load(os.path.join(input_base_path, file), map_location="cpu")
|
||||
for file in sorted(os.listdir(input_base_path))
|
||||
if file.endswith(".pth")
|
||||
]
|
||||
checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
|
||||
print("Loading in order:", checkpoint_list)
|
||||
loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list]
|
||||
param_count = 0
|
||||
index_dict = {"weight_map": {}}
|
||||
for layer_i in range(n_layers):
|
||||
@ -263,7 +259,7 @@ def write_model(
|
||||
"lm_head.weight": loaded["output.weight"],
|
||||
}
|
||||
else:
|
||||
concat_dim = 0 if llama_version == 3 else 1
|
||||
concat_dim = 0 if llama_version in ["3", "3.1"] else 1
|
||||
state_dict = {
|
||||
"model.norm.weight": loaded[0]["norm.weight"],
|
||||
"model.embed_tokens.weight": torch.cat(
|
||||
@ -282,6 +278,18 @@ def write_model(
|
||||
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
||||
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
|
||||
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
|
||||
|
||||
if llama_version in ["3", "3.1"]:
|
||||
bos_token_id = 128000
|
||||
|
||||
if instruct:
|
||||
eos_token_id = [128001, 128008, 128009]
|
||||
else:
|
||||
eos_token_id = 128001
|
||||
else:
|
||||
bos_token_id = 1
|
||||
eos_token_id = 2
|
||||
|
||||
config = LlamaConfig(
|
||||
hidden_size=dim,
|
||||
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
|
||||
@ -292,11 +300,21 @@ def write_model(
|
||||
vocab_size=vocab_size,
|
||||
rope_theta=base,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
bos_token_id=128000 if llama_version == 3 else 1,
|
||||
eos_token_id=128001 if llama_version == 3 else 2,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
config.save_pretrained(tmp_model_path)
|
||||
|
||||
if instruct:
|
||||
generation_config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.6,
|
||||
top_p=0.9,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
generation_config.save_pretrained(tmp_model_path)
|
||||
|
||||
# Make space so we can load the model properly now.
|
||||
del state_dict
|
||||
del loaded
|
||||
@ -313,7 +331,7 @@ def write_model(
|
||||
|
||||
|
||||
class Llama3Converter(TikTokenConverter):
|
||||
def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs):
|
||||
def __init__(self, vocab_file, special_tokens=None, instruct=False, model_max_length=None, **kwargs):
|
||||
super().__init__(vocab_file, **kwargs)
|
||||
tokenizer = self.converted()
|
||||
chat_template = (
|
||||
@ -327,34 +345,24 @@ class Llama3Converter(TikTokenConverter):
|
||||
"{% endfor %}"
|
||||
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
|
||||
)
|
||||
num_reserved_special_tokens = 256
|
||||
special_tokens = [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|reserved_special_token_0|>",
|
||||
"<|reserved_special_token_1|>",
|
||||
"<|reserved_special_token_2|>",
|
||||
"<|reserved_special_token_3|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|reserved_special_token_4|>",
|
||||
"<|eot_id|>", # end of turn
|
||||
] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
|
||||
self.tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
bos_token="<|begin_of_text|>",
|
||||
eos_token="<|end_of_text|>",
|
||||
chat_template=chat_template,
|
||||
eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
|
||||
chat_template=chat_template if instruct else None,
|
||||
model_input_names=["input_ids", "attention_mask"],
|
||||
model_max_length=model_max_length,
|
||||
)
|
||||
|
||||
|
||||
def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
|
||||
def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False):
|
||||
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
|
||||
if llama_version == 3:
|
||||
tokenizer = Llama3Converter(input_tokenizer_path).tokenizer
|
||||
if llama_version in ["3", "3.1"]:
|
||||
tokenizer = Llama3Converter(
|
||||
input_tokenizer_path, special_tokens, instruct, model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version]
|
||||
).tokenizer
|
||||
else:
|
||||
tokenizer = tokenizer_class(input_tokenizer_path)
|
||||
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
|
||||
@ -362,6 +370,37 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
|
||||
return tokenizer
|
||||
|
||||
|
||||
DEFAULT_LLAMA_SPECIAL_TOKENS = {
|
||||
"3": [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|reserved_special_token_0|>",
|
||||
"<|reserved_special_token_1|>",
|
||||
"<|reserved_special_token_2|>",
|
||||
"<|reserved_special_token_3|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|reserved_special_token_4|>",
|
||||
"<|eot_id|>", # end of turn
|
||||
]
|
||||
+ [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)],
|
||||
"3.1": [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|reserved_special_token_0|>",
|
||||
"<|reserved_special_token_1|>",
|
||||
"<|finetune_right_pad_id|>",
|
||||
"<|reserved_special_token_2|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|eom_id|>", # end of message
|
||||
"<|eot_id|>", # end of turn
|
||||
"<|python_tag|>",
|
||||
]
|
||||
+ [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@ -383,9 +422,9 @@ def main():
|
||||
# Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
|
||||
parser.add_argument(
|
||||
"--llama_version",
|
||||
choices=[1, 2, 3],
|
||||
default=1,
|
||||
type=int,
|
||||
choices=["1", "2", "3", "3.1"],
|
||||
default="1",
|
||||
type=str,
|
||||
help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -394,11 +433,34 @@ def main():
|
||||
type=int,
|
||||
help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--special_tokens",
|
||||
default=None,
|
||||
type=List[str],
|
||||
help="The list of special tokens that should be added to the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instruct",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether the model is an instruct model or not. Will affect special tokens for llama 3.1.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.model_size is None and args.num_shards is None:
|
||||
raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
|
||||
if args.special_tokens is None:
|
||||
args.special_tokens = DEFAULT_LLAMA_SPECIAL_TOKENS[str(args.llama_version)]
|
||||
|
||||
spm_path = os.path.join(args.input_dir, "tokenizer.model")
|
||||
vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version))
|
||||
vocab_size = len(
|
||||
write_tokenizer(
|
||||
args.output_dir,
|
||||
spm_path,
|
||||
llama_version=args.llama_version,
|
||||
special_tokens=args.special_tokens,
|
||||
instruct=args.instruct,
|
||||
)
|
||||
)
|
||||
if args.model_size != "tokenizer_only":
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
@ -408,6 +470,7 @@ def main():
|
||||
llama_version=args.llama_version,
|
||||
vocab_size=vocab_size,
|
||||
num_shards=args.num_shards,
|
||||
instruct=args.instruct,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user