Compare commits

...

2 Commits

SHA1        Message               Date
7fa7508dad  Release: v4.43.0      2024-07-23 16:58:49 +02:00
26b179c90d  Llama 3.1 conversion  2024-07-23 16:58:49 +02:00
56 changed files with 280 additions and 111 deletions

View File

@@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset
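
Every example-script hunk in this compare view makes the same one-line change: the development pin "4.43.0.dev0" becomes the released "4.43.0". As a point of reference, a minimal sketch of what such a version guard amounts to, assuming only that `check_min_version` compares the installed `transformers` version against the requested minimum and raises when it is older (the helper below is a hypothetical re-implementation, not the library's actual code):

# Hypothetical sketch of a minimum-version guard in the spirit of transformers.utils.check_min_version.
from packaging import version

import transformers


def require_transformers_at_least(min_version: str) -> None:
    # Compare the installed package version with the required minimum and fail loudly otherwise.
    installed = version.parse(transformers.__version__)
    if installed < version.parse(min_version):
        raise ImportError(
            f"This example requires transformers>={min_version}, but {installed} is installed. "
            "Run `pip install --upgrade transformers`."
        )


require_transformers_at_least("4.43.0")  # mirrors check_min_version("4.43.0") in the examples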

View File

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)

View File

@@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)

View File

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = logging.getLogger(__name__)

View File

@@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logging.basicConfig(level=logging.INFO)
 logger = get_logger(__name__)

View File

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)

View File

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = logging.getLogger(__name__)

View File

@@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 logger = logging.getLogger(__name__)

View File

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
 # region Checking dependencies
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 task_to_keys = {
     "cola": ("sentence", None),

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # region Dependencies and constants
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@@ -430,7 +430,7 @@ install_requires = [
 setup(
     name="transformers",
-    version="4.43.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.43.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="transformers@huggingface.co",
     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.43.0.dev0"
+__version__ = "4.43.0"
 from typing import TYPE_CHECKING
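An installed build of this tag therefore reports the release string at runtime; a trivial check using only the public attribute touched above:

import transformers

# With the v4.43.0 release installed, the ".dev0" suffix is gone from the version string.
print(transformers.__version__)  # expected: 4.43.0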

View File

@@ -129,6 +129,7 @@ def _compute_dynamic_ntk_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     if config is not None and len(rope_kwargs) > 0:
         raise ValueError(
             "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
@@ -249,6 +250,7 @@ def _compute_longrope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     # No need to keep BC with longrope, unreleased when this new pattern was created.
     if len(rope_kwargs) > 0:
         raise ValueError(
@@ -293,6 +295,50 @@ def _compute_longrope_parameters(
     return inv_freq, attention_factor


+def _compute_llama3_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in inv_freq:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
+    return inv_freq, attention_factor
+
+
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
 # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
 # parameterizations, as long as the callable has the same signature.
@@ -302,6 +348,7 @@ ROPE_INIT_FUNCTIONS = {
     "dynamic": _compute_dynamic_ntk_parameters,
     "yarn": _compute_yarn_parameters,
     "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
 }
@@ -339,6 +386,20 @@ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
         raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"original_max_position_embeddings"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
 def _validate_yarn_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
@@ -374,7 +435,8 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
     required_keys = {"rope_type", "short_factor", "long_factor"}
-    optional_keys = {"attention_factor", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
@@ -417,13 +479,50 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             )


+def _validate_llama3_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    if low_freq_factor is None or not isinstance(low_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+    if high_freq_factor is None or not isinstance(high_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+    if high_freq_factor < low_freq_factor:
+        raise ValueError(
+            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
+            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
+        )
+
+    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
+            f"{original_max_position_embeddings}"
+        )
+    if original_max_position_embeddings >= config.max_position_embeddings:
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
+            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
+        )
+
+
 # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
 ROPE_VALIDATION_FUNCTIONS = {
     "default": _validate_default_rope_parameters,
     "linear": _validate_linear_scaling_rope_parameters,
-    "dynamic": _validate_linear_scaling_rope_parameters,  # `dynamic` has the same validation pattern as `linear`
+    "dynamic": _validate_dynamic_scaling_rope_parameters,
     "yarn": _validate_yarn_parameters,
     "longrope": _validate_longrope_parameters,
+    "llama3": _validate_llama3_parameters,
 }
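
The new "llama3" entries plug into the existing dispatch tables, so modeling code keeps selecting the parameter and validation functions purely by the `rope_type` string. A hedged usage sketch of that dispatch, assuming a config whose `rope_scaling` dict carries the llama3 fields shown above (the numeric values mirror the comments in the diff; the config itself is illustrative, not a released checkpoint):

# Sketch only: obtain llama3-scaled inverse frequencies through the rope_type dispatch table.
import torch

from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = LlamaConfig(
    max_position_embeddings=131072,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

# The table maps the rope_type string to the matching parameter-init function.
rope_init_fn = ROPE_INIT_FUNCTIONS[config.rope_scaling["rope_type"]]
inv_freq, attention_factor = rope_init_fn(config, device=torch.device("cpu"))
print(inv_freq.shape, attention_factor)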

View File

@@ -73,25 +73,28 @@ class LlamaConfig(PretrainedConfig):
             End of stream token id.
         pretraining_tp (`int`, *optional*, defaults to 1):
             Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
-            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
-            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
-            issue](https://github.com/pytorch/pytorch/issues/76232).
+            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
-            `max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
-            to determine which scaling to apply.
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
             Expected contents:
                 `rope_type` (`str`):
-                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
-                    with 'default' being the original RoPE implementation.
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
                 `factor` (`float`, *optional*):
                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                     most scaling types, a `factor` of x will enable the model to handle sequences of length x *
-                    `max_position_embeddings`.
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
                 `attention_factor` (`float`, *optional*):
                     Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                     computation. If unspecified, it defaults to value recommended by the implementation, using the
@@ -104,12 +107,16 @@ class LlamaConfig(PretrainedConfig):
                     ramp function. If unspecified, it defaults to 1.
                 `short_factor` (`List[float]`, *optional*):
                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                     size divided by the number of attention heads divided by 2
                 `long_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                     size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):

View File

@@ -17,10 +17,11 @@ import json
 import os
 import shutil
 import warnings
+from typing import List

 import torch

-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
+from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
 from transformers.convert_slow_tokenizer import TikTokenConverter
@@ -85,8 +86,12 @@ NUM_SHARDS = {
     "65B": 8,
     "70B": 8,
     "70Bf": 8,
+    "405B": 8,
+    "405B-MP16": 16,
 }

+CONTEXT_LENGTH_FOR_VERSION = {"3.1": 131072, "3": 8192, "2": 4096, "1": 2048}
+

 def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
     return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
@@ -107,9 +112,10 @@ def write_model(
     input_base_path,
     model_size=None,
     safe_serialization=True,
-    llama_version=1,
+    llama_version="1",
     vocab_size=None,
     num_shards=None,
+    instruct=False,
 ):
     os.makedirs(model_path, exist_ok=True)
     tmp_model_path = os.path.join(model_path, "tmp")
@@ -125,18 +131,11 @@ def write_model(
     dims_per_head = dim // n_heads
     base = params.get("rope_theta", 10000.0)
     inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    if base > 10000.0 and llama_version != 3:
+    if base > 10000.0 and float(llama_version) < 3:
         max_position_embeddings = 16384
     else:
-        # Depending on the Llama version, the default max_position_embeddings has different values.
-        if llama_version == 1:
-            max_position_embeddings = 2048
-        elif llama_version == 2:
-            max_position_embeddings = 4096
-        elif llama_version == 3:
-            max_position_embeddings = 8192
-    vocab_size = vocab_size if vocab_size is not None else 32000
+        max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version]

     if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
         num_key_value_heads_per_shard = num_key_value_heads // num_shards
@@ -144,8 +143,7 @@ def write_model(
     else:  # compatibility with other checkpoints
         num_key_value_heads = n_heads
         num_key_value_heads_per_shard = n_heads_per_shard
-        key_value_dim = dims_per_head * num_key_value_heads
-    print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim)
+        key_value_dim = dim

     # permute for sliced rotary
     def permute(w, n_heads, dim1=dim, dim2=dim):
@@ -159,11 +157,9 @@ def write_model(
         loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
     else:
         # Sharded
-        loaded = [
-            torch.load(os.path.join(input_base_path, file), map_location="cpu")
-            for file in sorted(os.listdir(input_base_path))
-            if file.endswith(".pth")
-        ]
+        checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
+        print("Loading in order:", checkpoint_list)
+        loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list]
     param_count = 0
     index_dict = {"weight_map": {}}
     for layer_i in range(n_layers):
@@ -263,7 +259,7 @@ def write_model(
             "lm_head.weight": loaded["output.weight"],
         }
     else:
-        concat_dim = 0 if llama_version == 3 else 1
+        concat_dim = 0 if llama_version in ["3", "3.1"] else 1
         state_dict = {
             "model.norm.weight": loaded[0]["norm.weight"],
             "model.embed_tokens.weight": torch.cat(
@@ -282,6 +278,18 @@ def write_model(
     write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
     ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
     multiple_of = params["multiple_of"] if "multiple_of" in params else 256
+
+    if llama_version in ["3", "3.1"]:
+        bos_token_id = 128000
+
+        if instruct:
+            eos_token_id = [128001, 128008, 128009]
+        else:
+            eos_token_id = 128001
+    else:
+        bos_token_id = 1
+        eos_token_id = 2
+
     config = LlamaConfig(
         hidden_size=dim,
         intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
@@ -292,11 +300,21 @@ def write_model(
         vocab_size=vocab_size,
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        bos_token_id=128000 if llama_version == 3 else 1,
-        eos_token_id=128001 if llama_version == 3 else 2,
+        bos_token_id=bos_token_id,
+        eos_token_id=eos_token_id,
     )
     config.save_pretrained(tmp_model_path)

+    if instruct:
+        generation_config = GenerationConfig(
+            do_sample=True,
+            temperature=0.6,
+            top_p=0.9,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+        )
+        generation_config.save_pretrained(tmp_model_path)
+
     # Make space so we can load the model properly now.
     del state_dict
     del loaded
@@ -313,7 +331,7 @@ def write_model(
 class Llama3Converter(TikTokenConverter):
-    def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs):
+    def __init__(self, vocab_file, special_tokens=None, instruct=False, model_max_length=None, **kwargs):
         super().__init__(vocab_file, **kwargs)
         tokenizer = self.converted()
         chat_template = (
@@ -327,34 +345,24 @@ class Llama3Converter(TikTokenConverter):
             "{% endfor %}"
             "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
         )
-        num_reserved_special_tokens = 256
-        special_tokens = [
-            "<|begin_of_text|>",
-            "<|end_of_text|>",
-            "<|reserved_special_token_0|>",
-            "<|reserved_special_token_1|>",
-            "<|reserved_special_token_2|>",
-            "<|reserved_special_token_3|>",
-            "<|start_header_id|>",
-            "<|end_header_id|>",
-            "<|reserved_special_token_4|>",
-            "<|eot_id|>",  # end of turn
-        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]
         tokenizer.add_special_tokens(special_tokens)

         self.tokenizer = PreTrainedTokenizerFast(
             tokenizer_object=tokenizer,
             bos_token="<|begin_of_text|>",
-            eos_token="<|end_of_text|>",
-            chat_template=chat_template,
+            eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
+            chat_template=chat_template if instruct else None,
             model_input_names=["input_ids", "attention_mask"],
+            model_max_length=model_max_length,
         )


-def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
+def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False):
     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
-    if llama_version == 3:
-        tokenizer = Llama3Converter(input_tokenizer_path).tokenizer
+    if llama_version in ["3", "3.1"]:
+        tokenizer = Llama3Converter(
+            input_tokenizer_path, special_tokens, instruct, model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version]
+        ).tokenizer
     else:
         tokenizer = tokenizer_class(input_tokenizer_path)
     print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
@@ -362,6 +370,37 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
     return tokenizer


+DEFAULT_LLAMA_SPECIAL_TOKENS = {
+    "3": [
+        "<|begin_of_text|>",
+        "<|end_of_text|>",
+        "<|reserved_special_token_0|>",
+        "<|reserved_special_token_1|>",
+        "<|reserved_special_token_2|>",
+        "<|reserved_special_token_3|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|reserved_special_token_4|>",
+        "<|eot_id|>",  # end of turn
+    ]
+    + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)],
+    "3.1": [
+        "<|begin_of_text|>",
+        "<|end_of_text|>",
+        "<|reserved_special_token_0|>",
+        "<|reserved_special_token_1|>",
+        "<|finetune_right_pad_id|>",
+        "<|reserved_special_token_2|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|eom_id|>",  # end of message
+        "<|eot_id|>",  # end of turn
+        "<|python_tag|>",
+    ]
+    + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
+}
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -383,9 +422,9 @@ def main():
     # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
     parser.add_argument(
         "--llama_version",
-        choices=[1, 2, 3],
-        default=1,
-        type=int,
+        choices=["1", "2", "3", "3.1"],
+        default="1",
+        type=str,
         help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
     )
     parser.add_argument(
@@ -394,11 +433,34 @@ def main():
         type=int,
         help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
     )
+    parser.add_argument(
+        "--special_tokens",
+        default=None,
+        type=List[str],
+        help="The list of special tokens that should be added to the model.",
+    )
+    parser.add_argument(
+        "--instruct",
+        default=False,
+        type=bool,
+        help="Whether the model is an instruct model or not. Will affect special tokens for llama 3.1.",
+    )
     args = parser.parse_args()
     if args.model_size is None and args.num_shards is None:
         raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
+
+    if args.special_tokens is None:
+        args.special_tokens = DEFAULT_LLAMA_SPECIAL_TOKENS[str(args.llama_version)]
+
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
-    vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version))
+    vocab_size = len(
+        write_tokenizer(
+            args.output_dir,
+            spm_path,
+            llama_version=args.llama_version,
+            special_tokens=args.special_tokens,
+            instruct=args.instruct,
+        )
+    )
     if args.model_size != "tokenizer_only":
         write_model(
             model_path=args.output_dir,
@@ -408,6 +470,7 @@ def main():
             llama_version=args.llama_version,
             vocab_size=vocab_size,
             num_shards=args.num_shards,
+            instruct=args.instruct,
         )
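
After conversion, the output directory is a regular Hugging Face checkpoint (config, weights, tokenizer and, when `--instruct` is set, a generation config with the sampling defaults and multi-token EOS list added above). A hedged usage sketch for loading such a converted checkpoint; the output path and the assumption that the script was run with `--llama_version 3.1 --instruct True` are illustrative only:

# Illustrative only: load a checkpoint produced by the conversion script in this diff.
# Substitute whatever --output_dir was actually used; the path below is an assumption.
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

output_dir = "./llama3.1-instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)

# For instruct conversions, the saved tokenizer carries the chat template and the saved
# GenerationConfig carries do_sample/temperature/top_p and the eos list, so generate()
# picks those defaults up automatically.
messages = [{"role": "user", "content": "Give me one sentence about RoPE scaling."}]
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
outputs = model.generate(inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))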