Compare commits

...

8 Commits

Author SHA1 Message Date
75f15f39a0 Release: v4.41.1 2024-05-22 13:40:40 -04:00
8282db5cc9 Paligemma causal attention mask (#30967)
* PaliGemma working causal attention

* Formatting

* Style

* Docstrings + remove commented code

* Update docstring for PaliGemma Config

* PaliGemma - add separator ind to model/labels

* Refactor + docstring paligemma processor method

* Style

* return token type ids when tokenizing labels

* use token type ids when building causal mask

* add token type ids to tester

* remove separator from config

* fix style

* don't ignore separator

* add processor documentation

* simplify tokenization

* fix causal mask

* style

* fix label propagation, revert suffix naming

* fix style

* fix labels tokenization

* [run-slow]paligemma

* add eos if suffixes are present

* [run-slow]paligemma

* [run-slow]paligemma

* add missing tokens to fast version

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix style

* [run-slow]paligemma

---------

Co-authored-by: Peter Robicheaux <peter@roboflow.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2024-05-22 13:39:52 -04:00
e5b788ade3 Revert "feat: Upgrade Weights & Biases callback (#30135)"
This reverts commit 4ab7a28216211571fdddba414d4edd8426ab6489.
2024-05-22 12:39:27 -04:00
9d054596e7 Generation: get special tokens from model config (#30899)
* fix

* let's do this way?

* codestyle

* update

* add tests
2024-05-22 12:37:27 -04:00
e5d174f12a PaliGemma - fix processor with no input text (#30916)
Update processing_paligemma.py
2024-05-22 12:37:15 -04:00
04141855bd legacy to init the slow tokenizer when converting from slow was wrong (#30972) 2024-05-22 12:37:07 -04:00
6d2439a126 tokenizer_class = "AutoTokenizer" Llava Family (#30912)
propagate changes to more models
2024-05-22 12:36:58 -04:00
4c6c45ba13 Release: v4.41.0 2024-05-17 11:11:44 -04:00
60 changed files with 226 additions and 194 deletions

View File

@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
Array = Any Array = Any
Dataset = datasets.arrow_dataset.Dataset Dataset = datasets.arrow_dataset.Dataset

View File

@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk. # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt") require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
Array = Any Array = Any
Dataset = datasets.arrow_dataset.Dataset Dataset = datasets.arrow_dataset.Dataset

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type` # You should update this to your particular problem to have better documentation of `model_type`

View File

@ -48,7 +48,8 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.40.0.dev0") check_min_version("4.41.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt") require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.40.0.dev0") check_min_version("4.41.0")
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt") require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = get_logger(__name__) logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version( require_version(
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt" "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies # region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
task_to_keys = { task_to_keys = {
"cola": ("sentence", None), "cola": ("sentence", None),

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants # region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.41.0.dev0") check_min_version("4.41.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -426,7 +426,7 @@ install_requires = [
setup( setup(
name="transformers", name="transformers",
version="4.41.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) version="4.41.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)", author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co", author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow", description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends). # in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.41.0.dev0" __version__ = "4.41.1"
from typing import TYPE_CHECKING from typing import TYPE_CHECKING

View File

@ -1354,6 +1354,23 @@ class GenerationMixin:
self._static_cache.reset() # reset the cache for a new generation self._static_cache.reset() # reset the cache for a new generation
return self._static_cache return self._static_cache
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
else:
return
def _prepare_special_tokens( def _prepare_special_tokens(
self, self,
generation_config: GenerationConfig, generation_config: GenerationConfig,
@ -1378,11 +1395,16 @@ class GenerationMixin:
return token return token
return torch.tensor(token, device=device, dtype=torch.long) return torch.tensor(token, device=device, dtype=torch.long)
# for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
if self.config.is_encoder_decoder:
generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
generation_config.decoder_start_token_id, generation_config.bos_token_id
)
bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device) bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device) eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device) pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device) decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists). # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
if eos_token_id is not None and eos_token_id.ndim == 0: if eos_token_id is not None and eos_token_id.ndim == 0:

View File

@ -31,17 +31,8 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
import numpy as np import numpy as np
import packaging.version import packaging.version
from .. import PreTrainedModel, TFPreTrainedModel
from .. import __version__ as version from .. import __version__ as version
from ..utils import ( from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
PushToHubMixin,
flatten_dict,
is_datasets_available,
is_pandas_available,
is_tf_available,
is_torch_available,
logging,
)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -78,7 +69,6 @@ if TYPE_CHECKING and _has_neptune:
except importlib.metadata.PackageNotFoundError: except importlib.metadata.PackageNotFoundError:
_has_neptune = False _has_neptune = False
from .. import modelcard # noqa: E402
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402 from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402 from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402 from ..training_args import ParallelMode # noqa: E402
@ -673,22 +663,6 @@ class TensorBoardCallback(TrainerCallback):
self.tb_writer = None self.tb_writer = None
def save_model_architecture_to_file(model: Any, output_dir: str):
    """
    Write a textual description of `model` to `model_architecture.txt` inside `output_dir`.

    - `PreTrainedModel` instances are written via `print(model)`.
    - `TFPreTrainedModel` instances (when TensorFlow is available) are written via `model.summary`.
    - When PyTorch is available, any `torch.nn.Module` or `PushToHubMixin` that exposes a `base_model`
      attribute is written via `print(model)`.

    If `model` matches none of these cases the file is still created, but left empty.

    Args:
        model (`Any`): The model whose architecture should be dumped.
        output_dir (`str`): Existing directory in which the text file is created.
    """
    # An explicit UTF-8 encoding keeps the dump independent of the platform's locale codec, so
    # non-ASCII characters in the model description cannot raise `UnicodeEncodeError`.
    with open(f"{output_dir}/model_architecture.txt", "w+", encoding="utf-8") as f:
        if isinstance(model, PreTrainedModel):
            print(model, file=f)
        elif is_tf_available() and isinstance(model, TFPreTrainedModel):

            def print_to_file(s):
                # `model.summary` calls this for each piece of output; redirect it into the file.
                print(s, file=f)

            model.summary(print_fn=print_to_file)
        elif is_torch_available() and (
            isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
        ):
            print(model, file=f)
class WandbCallback(TrainerCallback): class WandbCallback(TrainerCallback):
""" """
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/). A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
@ -754,9 +728,6 @@ class WandbCallback(TrainerCallback):
if hasattr(model, "config") and model.config is not None: if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict() model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict} combined_dict = {**model_config, **combined_dict}
if hasattr(model, "peft_config") and model.peft_config is not None:
peft_config = model.peft_config
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
trial_name = state.trial_name trial_name = state.trial_name
init_args = {} init_args = {}
if trial_name is not None: if trial_name is not None:
@ -790,46 +761,6 @@ class WandbCallback(TrainerCallback):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps)) self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
self._wandb.run._label(code="transformers_trainer") self._wandb.run._label(code="transformers_trainer")
# add number of model parameters to wandb config
try:
self._wandb.config["model/num_parameters"] = model.num_parameters()
except AttributeError:
logger.info("Could not log the number of model parameters in Weights & Biases.")
# log the initial model and architecture to an artifact
with tempfile.TemporaryDirectory() as temp_dir:
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
model_artifact = self._wandb.Artifact(
name=model_name,
type="model",
metadata={
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
"num_parameters": self._wandb.config.get("model/num_parameters"),
"initial_model": True,
},
)
model.save_pretrained(temp_dir)
# add the architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with model_artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
badge_markdown = (
f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
f'0" height="32"/>]({self._wandb.run.get_url()})'
)
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def on_train_begin(self, args, state, control, model=None, **kwargs): def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None: if self._wandb is None:
return return
@ -860,25 +791,20 @@ class WandbCallback(TrainerCallback):
else { else {
f"eval/{args.metric_for_best_model}": state.best_metric, f"eval/{args.metric_for_best_model}": state.best_metric,
"train/total_floss": state.total_flos, "train/total_floss": state.total_flos,
"model/num_parameters": self._wandb.config.get("model/num_parameters"),
} }
) )
metadata["final_model"] = True
logger.info("Logging model artifacts. ...") logger.info("Logging model artifacts. ...")
model_name = ( model_name = (
f"model-{self._wandb.run.id}" f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir) if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}" else f"model-{self._wandb.run.name}"
) )
# add the model architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata) artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"): for f in Path(temp_dir).glob("*"):
if f.is_file(): if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa: with artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes()) fa.write(f.read_bytes())
self._wandb.run.log_artifact(artifact, aliases=["final_model"]) self._wandb.run.log_artifact(artifact)
def on_log(self, args, state, control, model=None, logs=None, **kwargs): def on_log(self, args, state, control, model=None, logs=None, **kwargs):
single_value_scalars = [ single_value_scalars = [
@ -908,30 +834,18 @@ class WandbCallback(TrainerCallback):
for k, v in dict(self._wandb.summary).items() for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_") if isinstance(v, numbers.Number) and not k.startswith("_")
} }
checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")
ckpt_dir = f"checkpoint-{state.global_step}" ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir) artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...") logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
checkpoint_name = ( checkpoint_name = (
f"model-{self._wandb.run.id}" f"checkpoint-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir) if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}" else f"checkpoint-{self._wandb.run.name}"
) )
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata) artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
artifact.add_dir(artifact_path) artifact.add_dir(artifact_path)
self._wandb.log_artifact( self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
)
def on_predict(self, args, state, control, metrics, **kwargs):
    """
    Log prediction metrics to Weights & Biases.

    Does nothing when `self._wandb` is `None`. Runs `self.setup` first if the callback has not been
    initialized yet, and only logs when `state.is_world_process_zero` is true.
    """
    if self._wandb is None:
        return
    if not self._initialized:
        self.setup(args, state, **kwargs)
    if not state.is_world_process_zero:
        return
    # Pass the metrics through `rewrite_logs` before handing them to wandb.
    self._wandb.log(rewrite_logs(metrics))
class CometCallback(TrainerCallback): class CometCallback(TrainerCallback):

View File

@ -151,9 +151,6 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
self.legacy = legacy self.legacy = legacy
if add_prefix_space is not None: if add_prefix_space is not None:
logger.warning_once(
"You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
)
kwargs["from_slow"] = True kwargs["from_slow"] = True
super().__init__( super().__init__(
@ -166,6 +163,7 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
add_bos_token=add_bos_token, add_bos_token=add_bos_token,
add_eos_token=add_eos_token, add_eos_token=add_eos_token,
use_default_system_prompt=use_default_system_prompt, use_default_system_prompt=use_default_system_prompt,
legacy=legacy,
**kwargs, **kwargs,
) )
self._add_bos_token = add_bos_token self._add_bos_token = add_bos_token

View File

@ -40,8 +40,8 @@ class LlavaNextProcessor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
image_processor_class = "LlavaNextImageProcessor" image_processor_class = "AutoImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None): def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer)

View File

@ -282,9 +282,14 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
self.vocab_size = model_embeds.num_embeddings self.vocab_size = model_embeds.num_embeddings
return model_embeds return model_embeds
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
):
_, _, embed_dim = image_features.shape _, _, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape batch_size, sequence_length = input_ids.shape
dtype, device = inputs_embeds.dtype, inputs_embeds.device
min_dtype = torch.finfo(dtype).min
scaled_image_features = image_features / (self.config.hidden_size**0.5) scaled_image_features = image_features / (self.config.hidden_size**0.5)
final_embedding = torch.zeros( final_embedding = torch.zeros(
batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
@ -305,24 +310,43 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
) )
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding) final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
if attention_mask is not None:
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
else:
position_ids = None
final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1) if token_type_ids is not None and labels is not None:
final_attention_mask_4d = final_attention_mask_4d.float().expand( # we are training thus we need to create a full mask on the image + prefix but causal on suffix
-1, self.config.text_config.num_key_value_heads, -1, -1 target_length = cache_position[-1] + 1
) causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
# unmask the prefill
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
token_type_ids[:, None, None, :] == 0, 0
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
# position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
if labels is not None:
final_labels = torch.full( final_labels = torch.full(
(batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
) )
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels) final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
else: else:
causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
final_labels = None final_labels = None
return final_embedding, final_attention_mask_4d, final_labels, position_ids return final_embedding, causal_mask, final_labels, position_ids
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@ -333,6 +357,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
@ -396,8 +421,10 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
selected_image_feature = image_outputs.last_hidden_state selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature) image_features = self.multi_modal_projector(selected_image_feature)
if cache_position is None:
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
) )
else: else:
@ -486,6 +513,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
cache_position=None, cache_position=None,
pixel_values=None, pixel_values=None,
attention_mask=None, attention_mask=None,
token_type_ids=None,
**kwargs, **kwargs,
): ):
past_length = 0 past_length = 0
@ -544,6 +572,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
"use_cache": kwargs.get("use_cache"), "use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask, "attention_mask": attention_mask,
"pixel_values": pixel_values, "pixel_values": pixel_values,
"token_type_ids": token_type_ids,
} }
) )
return model_inputs return model_inputs

File diff suppressed because one or more lines are too long

View File

@ -42,7 +42,7 @@ class VideoLlavaProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
image_processor_class = "VideoLlavaImageProcessor" image_processor_class = "VideoLlavaImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None): def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer)

View File

@ -65,6 +65,7 @@ if is_torch_available():
GenerateBeamEncoderDecoderOutput, GenerateBeamEncoderDecoderOutput,
GenerateDecoderOnlyOutput, GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput, GenerateEncoderDecoderOutput,
GenerationConfig,
GreedySearchDecoderOnlyOutput, GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput, GreedySearchEncoderDecoderOutput,
LogitsProcessorList, LogitsProcessorList,
@ -2478,6 +2479,35 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist()) self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())
def test_decoder_start_id_from_config(self):
    """Check where `generate` takes `decoder_start_token_id` from when given a fresh `GenerationConfig` (#30899)."""
    checkpoint = "hf-internal-testing/tiny-random-bart"
    articles = [
        "Justin Timberlake and Jessica Biel, welcome to parenthood.",
        "Michael Phelps is arguably the most decorated Olympian of all time.",
    ]
    bart_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    bart_model = BartForConditionalGeneration.from_pretrained(checkpoint).to(torch_device)
    input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)

    # Remember the model's own start token before it is cleared further down.
    decoder_start_token_id = bart_model.generation_config.decoder_start_token_id

    # A user-supplied `GenerationConfig` without `decoder_start_token_id` should fall back to the
    # value stored in the model's generation config.
    outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))

    # Once the model's generation config has neither `decoder_start_token_id` nor `bos_token_id`,
    # generation only works if the user passes the id in their own config.
    bart_model.generation_config.decoder_start_token_id = None
    bart_model.generation_config.bos_token_id = None
    user_config = GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id)
    outputs_with_user_id = bart_model.generate(input_ids, generation_config=user_config)

    self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist())

    # With no start token available from either config, `generate` is expected to raise.
    with self.assertRaises(ValueError):
        bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
def test_contrastive_search_batched(self): def test_contrastive_search_batched(self):
# PT-only test: TF doesn't have constrained beam search # PT-only test: TF doesn't have constrained beam search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs) # Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)

View File

@ -163,6 +163,8 @@ class PaliGemmaVisionText2TextModelTester:
"pixel_values": pixel_values, "pixel_values": pixel_values,
"input_ids": input_ids, "input_ids": input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"labels": input_ids,
"token_type_ids": torch.zeros_like(input_ids),
} }
return config, inputs_dict return config, inputs_dict