Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-21 09:44:02 +08:00

Compare commits: cache-refa ... v4.43.0 (2 commits)

Commits: 7fa7508dad, 26b179c90d
@@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 Array = Any
 Dataset = datasets.arrow_dataset.Dataset
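Every example-script hunk in this compare is the same release chore: the minimum-version guard is bumped from the development tag 4.43.0.dev0 to the released 4.43.0. A minimal sketch of the guard these hunks touch, assuming only a stock Transformers install (the call is the one imported in the hunk headers above):

    from transformers.utils import check_min_version

    # Raises an informative error when the installed transformers is older than the pinned release.
    check_min_version("4.43.0")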
@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry

 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 Array = Any
 Dataset = datasets.arrow_dataset.Dataset
@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)

@@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = logging.getLogger(__name__)

@@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`
@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logging.basicConfig(level=logging.INFO)
 logger = get_logger(__name__)
@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 logger = logging.getLogger(__name__)

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version

 # region Checking dependencies
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 task_to_keys = {
     "cola": ("sentence", None),
@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

 # region Dependencies and constants
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.43.0.dev0")
+check_min_version("4.43.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

setup.py (2 changed lines)
@@ -430,7 +430,7 @@ install_requires = [

 setup(
     name="transformers",
-    version="4.43.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.43.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="transformers@huggingface.co",
     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).

-__version__ = "4.43.0.dev0"
+__version__ = "4.43.0"

 from typing import TYPE_CHECKING

@@ -129,6 +129,7 @@ def _compute_dynamic_ntk_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     if config is not None and len(rope_kwargs) > 0:
         raise ValueError(
             "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
@@ -249,6 +250,7 @@ def _compute_longrope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     # No need to keep BC with longrope, unreleased when this new pattern was created.
     if len(rope_kwargs) > 0:
         raise ValueError(
@@ -293,6 +295,50 @@ def _compute_longrope_parameters(
     return inv_freq, attention_factor


+def _compute_llama3_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in inv_freq:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
+    return inv_freq, attention_factor
+
+
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
 # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
 # parameterizations, as long as the callable has the same signature.
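To make the branchy math above concrete, here is a worked instance with the default values called out in the inline comments (factor 8, low_freq_factor 1, high_freq_factor 4, original_max_position_embeddings 8192); it is a standalone sketch, not part of the diff:

    import math

    old_context_len = 8192
    low_freq_wavelen = old_context_len / 1.0   # 8192: wavelengths longer than this are divided by factor=8
    high_freq_wavelen = old_context_len / 4.0  # 2048: wavelengths shorter than this are kept unchanged

    freq = 2 * math.pi / 4096                  # a mid-band frequency whose wavelength is 4096 tokens
    smooth = (old_context_len / 4096 - 1.0) / (4.0 - 1.0)     # = 1/3
    scaled = (1 - smooth) * freq / 8.0 + smooth * freq        # interpolates between freq/8 and the unscaled freq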
@@ -302,6 +348,7 @@ ROPE_INIT_FUNCTIONS = {
     "dynamic": _compute_dynamic_ntk_parameters,
     "yarn": _compute_yarn_parameters,
     "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
 }

@@ -339,6 +386,20 @@ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
         raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"original_max_position_embeddings"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
 def _validate_yarn_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
@@ -374,7 +435,8 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
     required_keys = {"rope_type", "short_factor", "long_factor"}
-    optional_keys = {"attention_factor", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys, optional_keys)

@@ -417,13 +479,50 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             )


+def _validate_llama3_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    if low_freq_factor is None or not isinstance(low_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+    if high_freq_factor is None or not isinstance(high_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+    if high_freq_factor < low_freq_factor:
+        raise ValueError(
+            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
+            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
+        )
+
+    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
+            f"{original_max_position_embeddings}"
+        )
+    if original_max_position_embeddings >= config.max_position_embeddings:
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
+            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
+        )
+
+
 # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
 ROPE_VALIDATION_FUNCTIONS = {
     "default": _validate_default_rope_parameters,
     "linear": _validate_linear_scaling_rope_parameters,
-    "dynamic": _validate_linear_scaling_rope_parameters,  # `dynamic` has the same validation pattern as `linear`
+    "dynamic": _validate_dynamic_scaling_rope_parameters,
     "yarn": _validate_yarn_parameters,
     "longrope": _validate_longrope_parameters,
+    "llama3": _validate_llama3_parameters,
 }

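Taken together, the registration and validation above mean a config only needs a well-formed `rope_scaling` dict to opt into the new behaviour. A minimal usage sketch, assuming transformers >= 4.43.0 and illustrative values (the defaults quoted in the conversion comments); it is not part of the diff:

    import torch
    from transformers import LlamaConfig
    from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

    # Constructing the config runs the llama3 validator added above.
    config = LlamaConfig(
        max_position_embeddings=131072,
        rope_scaling={
            "rope_type": "llama3",
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
        },
    )

    # The rotary embedding layer looks its initializer up by rope_type.
    rope_init_fn = ROPE_INIT_FUNCTIONS[config.rope_scaling["rope_type"]]
    inv_freq, attention_factor = rope_init_fn(config, device=torch.device("cpu"))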
@@ -73,25 +73,28 @@ class LlamaConfig(PretrainedConfig):
             End of stream token id.
         pretraining_tp (`int`, *optional*, defaults to 1):
             Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
-            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
-            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
-            issue](https://github.com/pytorch/pytorch/issues/76232).
+            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
-            `max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
-            to determine which scaling to apply.
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
             Expected contents:
                 `rope_type` (`str`):
-                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
-                    with 'default' being the original RoPE implementation.
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
                 `factor` (`float`, *optional*):
                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                     most scaling types, a `factor` of x will enable the model to handle sequences of length x *
-                    `max_position_embeddings`.
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
                 `attention_factor` (`float`, *optional*):
                     Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                     computation. If unspecified, it defaults to value recommended by the implementation, using the
@@ -104,12 +107,16 @@ class LlamaConfig(PretrainedConfig):
                     ramp function. If unspecified, it defaults to 1.
                 `short_factor` (`List[float]`, *optional*):
                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                     size divided by the number of attention heads divided by 2
                 `long_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                     size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -17,10 +17,11 @@ import json
 import os
 import shutil
 import warnings
+from typing import List

 import torch

-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
+from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
 from transformers.convert_slow_tokenizer import TikTokenConverter

@@ -85,8 +86,12 @@ NUM_SHARDS = {
     "65B": 8,
     "70B": 8,
     "70Bf": 8,
+    "405B": 8,
+    "405B-MP16": 16,
 }

+CONTEXT_LENGTH_FOR_VERSION = {"3.1": 131072, "3": 8192, "2": 4096, "1": 2048}
+

 def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
     return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
@@ -107,9 +112,10 @@ def write_model(
     input_base_path,
     model_size=None,
     safe_serialization=True,
-    llama_version=1,
+    llama_version="1",
     vocab_size=None,
     num_shards=None,
+    instruct=False,
 ):
     os.makedirs(model_path, exist_ok=True)
     tmp_model_path = os.path.join(model_path, "tmp")
@@ -125,18 +131,11 @@ def write_model(
     dims_per_head = dim // n_heads
     base = params.get("rope_theta", 10000.0)
     inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    if base > 10000.0 and llama_version != 3:
+    if base > 10000.0 and float(llama_version) < 3:
         max_position_embeddings = 16384
     else:
-        # Depending on the Llama version, the default max_position_embeddings has different values.
-        if llama_version == 1:
-            max_position_embeddings = 2048
-        elif llama_version == 2:
-            max_position_embeddings = 4096
-        elif llama_version == 3:
-            max_position_embeddings = 8192
+        max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version]

-    vocab_size = vocab_size if vocab_size is not None else 32000
     if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
         num_key_value_heads_per_shard = num_key_value_heads // num_shards
@@ -144,8 +143,7 @@ def write_model(
     else:  # compatibility with other checkpoints
         num_key_value_heads = n_heads
         num_key_value_heads_per_shard = n_heads_per_shard
-        key_value_dim = dims_per_head * num_key_value_heads
-    print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim)
+        key_value_dim = dim

     # permute for sliced rotary
     def permute(w, n_heads, dim1=dim, dim2=dim):
@@ -159,11 +157,9 @@ def write_model(
         loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
     else:
         # Sharded
-        loaded = [
-            torch.load(os.path.join(input_base_path, file), map_location="cpu")
-            for file in sorted(os.listdir(input_base_path))
-            if file.endswith(".pth")
-        ]
+        checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
+        print("Loading in order:", checkpoint_list)
+        loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list]
     param_count = 0
     index_dict = {"weight_map": {}}
     for layer_i in range(n_layers):
@@ -263,7 +259,7 @@ def write_model(
             "lm_head.weight": loaded["output.weight"],
         }
     else:
-        concat_dim = 0 if llama_version == 3 else 1
+        concat_dim = 0 if llama_version in ["3", "3.1"] else 1
         state_dict = {
             "model.norm.weight": loaded[0]["norm.weight"],
             "model.embed_tokens.weight": torch.cat(
@@ -282,6 +278,18 @@ def write_model(
     write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
     ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
     multiple_of = params["multiple_of"] if "multiple_of" in params else 256
+
+    if llama_version in ["3", "3.1"]:
+        bos_token_id = 128000
+
+        if instruct:
+            eos_token_id = [128001, 128008, 128009]
+        else:
+            eos_token_id = 128001
+    else:
+        bos_token_id = 1
+        eos_token_id = 2
+
     config = LlamaConfig(
         hidden_size=dim,
         intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
@@ -292,11 +300,21 @@ def write_model(
         vocab_size=vocab_size,
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        bos_token_id=128000 if llama_version == 3 else 1,
-        eos_token_id=128001 if llama_version == 3 else 2,
+        bos_token_id=bos_token_id,
+        eos_token_id=eos_token_id,
     )
     config.save_pretrained(tmp_model_path)

+    if instruct:
+        generation_config = GenerationConfig(
+            do_sample=True,
+            temperature=0.6,
+            top_p=0.9,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+        )
+        generation_config.save_pretrained(tmp_model_path)
+
     # Make space so we can load the model properly now.
     del state_dict
     del loaded
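For instruct checkpoints the script now also saves a generation config whose `eos_token_id` is a list, so generation stops on whichever of those ids is produced first. A standalone sketch using the same values as the added branches above; the output directory is a placeholder, and the id-to-token mapping in the comment assumes the usual 128,000-entry base vocabulary:

from transformers import GenerationConfig

generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    bos_token_id=128000,
    # Assuming a 128k base vocab: <|end_of_text|>, <|eom_id|>, <|eot_id|>.
    eos_token_id=[128001, 128008, 128009],
)
generation_config.save_pretrained("llama-3.1-instruct-hf")  # placeholder directory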
@@ -313,7 +331,7 @@ def write_model(


 class Llama3Converter(TikTokenConverter):
-    def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs):
+    def __init__(self, vocab_file, special_tokens=None, instruct=False, model_max_length=None, **kwargs):
         super().__init__(vocab_file, **kwargs)
         tokenizer = self.converted()
         chat_template = (
@@ -327,34 +345,24 @@ class Llama3Converter(TikTokenConverter):
             "{% endfor %}"
             "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
         )
-        num_reserved_special_tokens = 256
-        special_tokens = [
-            "<|begin_of_text|>",
-            "<|end_of_text|>",
-            "<|reserved_special_token_0|>",
-            "<|reserved_special_token_1|>",
-            "<|reserved_special_token_2|>",
-            "<|reserved_special_token_3|>",
-            "<|start_header_id|>",
-            "<|end_header_id|>",
-            "<|reserved_special_token_4|>",
-            "<|eot_id|>", # end of turn
-        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]
         tokenizer.add_special_tokens(special_tokens)

         self.tokenizer = PreTrainedTokenizerFast(
             tokenizer_object=tokenizer,
             bos_token="<|begin_of_text|>",
-            eos_token="<|end_of_text|>",
-            chat_template=chat_template,
+            eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
+            chat_template=chat_template if instruct else None,
             model_input_names=["input_ids", "attention_mask"],
+            model_max_length=model_max_length,
         )


-def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
+def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False):
     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
-    if llama_version == 3:
-        tokenizer = Llama3Converter(input_tokenizer_path).tokenizer
+    if llama_version in ["3", "3.1"]:
+        tokenizer = Llama3Converter(
+            input_tokenizer_path, special_tokens, instruct, model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version]
+        ).tokenizer
     else:
         tokenizer = tokenizer_class(input_tokenizer_path)
     print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
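A hedged sketch of how the updated converter would be driven directly for a Llama 3.1 instruct tokenizer; the input path is a placeholder, and `DEFAULT_LLAMA_SPECIAL_TOKENS` / `CONTEXT_LENGTH_FOR_VERSION` are the tables defined elsewhere in this script:

converter = Llama3Converter(
    "/path/to/original/tokenizer.model",  # placeholder
    special_tokens=DEFAULT_LLAMA_SPECIAL_TOKENS["3.1"],
    instruct=True,
    model_max_length=CONTEXT_LENGTH_FOR_VERSION["3.1"],
)
tokenizer = converter.tokenizer
# Instruct: eos_token is <|eot_id|> and the chat template is attached;
# base checkpoints keep <|end_of_text|> and carry no chat template.
print(tokenizer.eos_token, tokenizer.model_max_length)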
@@ -362,6 +370,37 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
     return tokenizer


+DEFAULT_LLAMA_SPECIAL_TOKENS = {
+    "3": [
+        "<|begin_of_text|>",
+        "<|end_of_text|>",
+        "<|reserved_special_token_0|>",
+        "<|reserved_special_token_1|>",
+        "<|reserved_special_token_2|>",
+        "<|reserved_special_token_3|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|reserved_special_token_4|>",
+        "<|eot_id|>", # end of turn
+    ]
+    + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)],
+    "3.1": [
+        "<|begin_of_text|>",
+        "<|end_of_text|>",
+        "<|reserved_special_token_0|>",
+        "<|reserved_special_token_1|>",
+        "<|finetune_right_pad_id|>",
+        "<|reserved_special_token_2|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|eom_id|>", # end of message
+        "<|eot_id|>", # end of turn
+        "<|python_tag|>",
+    ]
+    + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
+}
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
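Both defaults pad the explicit entries with reserved placeholders so that each version ends up with exactly 256 special tokens, matching the `num_reserved_special_tokens = 256` constant the old in-class list used. A quick sanity check over the table as written above:

# Each per-version default should contain 256 unique special tokens:
# "3" has 10 explicit entries + 246 reserved, "3.1" has 11 explicit + 245 reserved.
for version, tokens in DEFAULT_LLAMA_SPECIAL_TOKENS.items():
    assert len(tokens) == 256, (version, len(tokens))
    assert len(set(tokens)) == len(tokens)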
@@ -383,9 +422,9 @@ def main():
     # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
     parser.add_argument(
         "--llama_version",
-        choices=[1, 2, 3],
-        default=1,
-        type=int,
+        choices=["1", "2", "3", "3.1"],
+        default="1",
+        type=str,
         help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
     )
     parser.add_argument(
@@ -394,11 +433,34 @@ def main():
         type=int,
         help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
     )
+    parser.add_argument(
+        "--special_tokens",
+        default=None,
+        type=List[str],
+        help="The list of special tokens that should be added to the model.",
+    )
+    parser.add_argument(
+        "--instruct",
+        default=False,
+        type=bool,
+        help="Whether the model is an instruct model or not. Will affect special tokens for llama 3.1.",
+    )
     args = parser.parse_args()
     if args.model_size is None and args.num_shards is None:
         raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
+    if args.special_tokens is None:
+        args.special_tokens = DEFAULT_LLAMA_SPECIAL_TOKENS[str(args.llama_version)]
+
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
-    vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version))
+    vocab_size = len(
+        write_tokenizer(
+            args.output_dir,
+            spm_path,
+            llama_version=args.llama_version,
+            special_tokens=args.special_tokens,
+            instruct=args.instruct,
+        )
+    )
     if args.model_size != "tokenizer_only":
         write_model(
             model_path=args.output_dir,
@@ -408,6 +470,7 @@ def main():
             llama_version=args.llama_version,
             vocab_size=vocab_size,
             num_shards=args.num_shards,
+            instruct=args.instruct,
         )
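End to end, a Llama 3.1 instruct conversion driven from Python rather than the command line would look roughly like the sketch below. Paths are placeholders, and `write_model`'s remaining parameters (such as the input checkpoint argument, named `input_base_path` here) are assumed from parts of the script outside this diff:

import os

output_dir = "Meta-Llama-3.1-8B-Instruct-hf"       # placeholder
input_dir = "/path/to/Meta-Llama-3.1-8B-Instruct"  # placeholder

vocab_size = len(
    write_tokenizer(
        output_dir,
        os.path.join(input_dir, "tokenizer.model"),
        llama_version="3.1",
        special_tokens=DEFAULT_LLAMA_SPECIAL_TOKENS["3.1"],
        instruct=True,
    )
)
write_model(
    model_path=output_dir,
    input_base_path=input_dir,  # assumed parameter name, not shown in this diff
    llama_version="3.1",
    vocab_size=vocab_size,
    num_shards=1,
    instruct=True,
)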