Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 17:48:57 +08:00)

Compare commits: serve-quan...v4.46.2 (14 commits)
| SHA1 |
|---|
| ccbd57a8b6 |
| e66224b544 |
| 8c62a92b3c |
| 5b36cdabf5 |
| f784d95c0f |
| 7da0eefc27 |
| bc598c00db |
| 94ed13c1de |
| 72c716de92 |
| 97bb9299c4 |
| 565f0e97c2 |
| dcfe3c7e61 |
| c2820c9491 |
| b298161146 |
@@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

Array = Any
Dataset = datasets.arrow_dataset.Dataset
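Every example script in this comparison carries the same change: the development pin `check_min_version("4.46.0.dev0")` becomes the release pin `check_min_version("4.46.0")`. As a rough illustration of what such a gate does (a minimal sketch of the same idea, not the transformers implementation; the helper name `require_at_least` is hypothetical):

```python
from packaging import version

import transformers


def require_at_least(minimum: str) -> None:
    """Raise early if the installed transformers is older than `minimum`."""
    if version.parse(transformers.__version__) < version.parse(minimum):
        raise ImportError(
            f"This example requires transformers>={minimum}, "
            f"but found {transformers.__version__}. Please upgrade."
        )


require_at_least("4.46.0")  # same intent as the check_min_version call above
```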
@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)

@@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
@@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = logging.getLogger(__name__)

@@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version(
    "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
logger = logging.getLogger(__name__)

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
task_to_keys = {
    "cola": ("sentence", None),

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
setup.py

@@ -435,7 +435,7 @@ install_requires = [
setup(
    name="transformers",
-    version="4.46.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.46.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
    author_email="transformers@huggingface.co",
    description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).

-__version__ = "4.46.0.dev0"
+__version__ = "4.46.2"

from typing import TYPE_CHECKING
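The release bump touches both `setup.py` and the package `__version__` shown above; after installing the v4.46.2 tag, the runtime string should match. A quick sanity check (nothing beyond the public `__version__` attribute):

```python
import transformers

# Should print/assert the version set in the two hunks above once v4.46.2 is installed.
assert transformers.__version__ == "4.46.2", transformers.__version__
```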
@@ -28,7 +28,7 @@ import tempfile
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
-from functools import lru_cache, partial, wraps
+from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile
@@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
        old_param = model
        splits = param_name.split(".")
        for split in splits:
-            old_param = getattr(old_param, split)
-            # Not all the attributes of a module are Parameters/Tensor
-            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
-                old_param = None
+            # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
+            old_param = getattr(old_param, split, None)
+            if old_param is None:
+                break
+
+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None

        if old_param is not None:
            if dtype is None:
                param = param.to(old_param.dtype)
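The new loop in `_load_state_dict_into_meta_model` swaps a raise-on-missing `getattr` for a defaulted one, so parameter names injected by quantizers (the comment mentions hqq) no longer crash the walk. A minimal, self-contained sketch of the same traversal pattern (`resolve_param` is a hypothetical helper, not the function in the diff):

```python
import torch
from torch import nn


def resolve_param(model: nn.Module, param_name: str):
    """Follow a dotted parameter name, returning None instead of raising
    when a segment does not exist on the module tree."""
    obj = model
    for part in param_name.split("."):
        obj = getattr(obj, part, None)
        if obj is None:
            break
    if not isinstance(obj, (nn.Parameter, torch.Tensor)):
        return None
    return obj


model = nn.Sequential(nn.Linear(4, 4))
print(resolve_param(model, "0.weight").shape)  # torch.Size([4, 4])
print(resolve_param(model, "0.not_there"))     # None rather than AttributeError
```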
@@ -5013,7 +5014,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
        return self.hf_quantizer.is_trainable

    @property
-    @lru_cache
    def loss_function(self):
        if getattr(self.config, "loss_type", None) is not None:
            loss_type = self.config.loss_type
@@ -1288,7 +1288,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
        if pixel_values is not None:
            image_tokens = self.get_image_tokens(pixel_values)
            n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
-            n_image_features = image_tokens.shape[0]
+            n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
            if n_image_tokens_in_text != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
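The Chameleon fix changes how the expected number of image placeholders is computed: `get_image_tokens` returns one row of discrete tokens per image, so the total must be batch size times tokens per image, not the batch size alone. A toy illustration with made-up shapes:

```python
import torch

# Hypothetical output of get_image_tokens: 2 images, 1024 image tokens each.
image_tokens = torch.zeros(2, 1024, dtype=torch.long)

n_image_features = image_tokens.shape[0] * image_tokens.shape[1]  # new check
print(n_image_features)       # 2048 placeholders expected in input_ids
print(image_tokens.shape[0])  # 2 -- the old check, which undercounts
```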
@@ -467,6 +467,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
            (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
        ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

+        image_features = None
        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
@@ -474,69 +475,68 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-                )
-                # prefill stage vs decoding stage (legacy behavior copied)
-                if input_ids.shape[1] != 1:
-                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                        image_features, inputs_embeds, input_ids, attention_mask, labels
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    # Get the target length
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # if one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
@@ -597,12 +597,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,

@@ -613,7 +607,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
            **kwargs,
        )

-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = pixel_values
@@ -846,6 +846,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
            (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
        ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

+        image_features = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            image_features = self.get_image_features(
                pixel_values,
@@ -861,74 +862,73 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )

-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-                )
-                if input_ids.shape[1] != 1:
-                    inputs_embeds = inputs_embeds.to(image_features.dtype)
-                    inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
-                        image_features,
-                        feature_lens,
-                        inputs_embeds,
-                        input_ids,
-                        attention_mask,
-                        position_ids,
-                        labels=labels,
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    # Get the target length
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # if one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
-                n_image_features = image_features.shape[0]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            if input_ids.shape[1] != 1:
+                inputs_embeds = inputs_embeds.to(image_features.dtype)
+                inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
+                    image_features,
+                    feature_lens,
+                    inputs_embeds,
+                    input_ids,
+                    attention_mask,
+                    position_ids,
+                    labels=labels,
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
@@ -990,11 +990,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,

@@ -1007,7 +1002,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):

        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_sizes"] = image_sizes
@@ -1020,6 +1020,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
        if image_features is not None:
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0]

            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -1110,17 +1111,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
    ):
        # Overwritten -- extra custom processing

-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,

@@ -1133,7 +1123,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):

        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["pixel_values_videos"] = pixel_values_videos
            model_inputs["image_sizes"] = image_sizes
@@ -533,6 +533,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
        if image_features is not None:
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0]

            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -623,17 +624,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
    ):
        # Overwritten -- extra custom processing

-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,

@@ -646,7 +636,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):

        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["pixel_values_videos"] = pixel_values_videos
            model_inputs["image_sizes"] = image_sizes
@@ -679,6 +679,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
            )
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0]

            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -704,6 +705,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
            )
            video_features = torch.cat((video_features, image_newline), dim=1)
            video_features = video_features.flatten(0, 1)

            n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
            n_video_features = video_features.shape[0]
            if n_video_tokens != n_video_features:
@@ -1156,7 +1156,7 @@ class MimiTransformerModel(nn.Module):
                sliding_attend_mask = torch.arange(target_length, device=device) <= (
                    cache_position.reshape(-1, 1) - config.sliding_window
                )
-                diagonal_attend_mask |= sliding_attend_mask
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
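The same one-line change appears in several decoder models below: the augmented assignment `|=` on the sliding-window mask is replaced by an explicit in-place `bitwise_or_` call. The diff does not state the motivation, but the two spellings are element-wise equivalent, which this sketch (with made-up sizes) checks:

```python
import torch

cache_position = torch.arange(4)
target_length, sliding_window = 6, 2

diagonal = torch.arange(target_length) > cache_position.reshape(-1, 1)
sliding = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)

a = diagonal.clone()
a |= sliding              # old spelling
b = diagonal.clone()
b.bitwise_or_(sliding)    # new spelling
assert torch.equal(a, b)  # identical boolean masks
```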
@@ -961,7 +961,7 @@ class MistralModel(MistralPreTrainedModel):
                sliding_attend_mask = torch.arange(target_length, device=device) <= (
                    cache_position.reshape(-1, 1) - config.sliding_window
                )
-                diagonal_attend_mask |= sliding_attend_mask
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:

@@ -1174,7 +1174,7 @@ class MixtralModel(MixtralPreTrainedModel):
                sliding_attend_mask = torch.arange(target_length, device=device) <= (
                    cache_position.reshape(-1, 1) - config.sliding_window
                )
-                diagonal_attend_mask |= sliding_attend_mask
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:

@@ -1385,7 +1385,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
                sliding_attend_mask = torch.arange(target_length, device=device) <= (
                    cache_position.reshape(-1, 1) - config.sliding_window
                )
-                diagonal_attend_mask |= sliding_attend_mask
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:

@@ -1689,7 +1689,7 @@ class MoshiModel(MoshiPreTrainedModel):
                sliding_attend_mask = torch.arange(target_length, device=device) <= (
                    cache_position.reshape(-1, 1) - config.sliding_window
                )
-                diagonal_attend_mask |= sliding_attend_mask
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:

@@ -1139,7 +1139,7 @@ class Phi3Model(Phi3PreTrainedModel):
                sliding_attend_mask = torch.arange(target_length, device=device) <= (
                    cache_position.reshape(-1, 1) - config.sliding_window
                )
-                diagonal_attend_mask |= sliding_attend_mask
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:

@@ -1305,7 +1305,7 @@ class PhimoeModel(PhimoePreTrainedModel):
                sliding_attend_mask = torch.arange(target_length, device=device) <= (
                    cache_position.reshape(-1, 1) - config.sliding_window
                )
-                diagonal_attend_mask |= sliding_attend_mask
+                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
@@ -762,11 +762,14 @@ class Pix2StructTextAttention(nn.Module):
        return relative_buckets

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
-    def compute_bias(self, query_length, key_length, device=None):
+    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
-        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        if cache_position is None:
+            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        else:
+            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
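Passing `cache_position` into `compute_bias` matters during cached decoding: the query covers only the newly generated positions, so their absolute positions must come from the cache rather than from `arange(query_length)`. A small numeric illustration (toy values):

```python
import torch

# Single-step decode: the one query token sits at absolute position 5,
# while the keys cover positions 0..5 (past + current).
key_length = 6
cache_position = torch.tensor([5])

context_position = cache_position[:, None]   # new code path
memory_position = torch.arange(key_length)[None, :]
print(memory_position - context_position)    # tensor([[-5, -4, -3, -2, -1,  0]])

# The old path would have used arange(query_length=1), i.e. position 0,
# yielding relative positions [[0, 1, 2, 3, 4, 5]] -- wrong for cached decoding.
```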
@@ -779,6 +782,7 @@ class Pix2StructTextAttention(nn.Module):
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
    def forward(
        self,
        hidden_states,
@@ -796,61 +800,66 @@ class Pix2StructTextAttention(nn.Module):
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
-        # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length)
+        # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

-        query_states = self.query(hidden_states).contiguous()
+        query_states = self.query(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
-                past_key_value = past_key_value.cross_attention_cache
+                curr_past_key_value = past_key_value.cross_attention_cache
            else:
-                past_key_value = past_key_value.self_attention_cache
+                curr_past_key_value = past_key_value.self_attention_cache

        # get key/value states
        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value and is_updated:
            # reuse k,v, cross_attentions
-            key_states = past_key_value.key_cache[self.layer_idx]
-            value_states = past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
-            key_states = self.key(current_states).contiguous()
-            value_states = self.value(current_states).contiguous()
+            key_states = self.key(current_states)
+            value_states = self.value(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
-                key_states, value_states = past_key_value.update(
+                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

-        # compute scores
+        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
-            real_seq_length = cache_position[-1] + 1 if query_length is None else query_length
-            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+            key_length = key_states.shape[-2]
+            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
+            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
-                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
+                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
-                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+                position_bias = self.compute_bias(
+                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
+                )
+                position_bias = position_bias[:, :, -seq_length:, :]

        if mask is not None:
-            position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
+            causal_mask = mask[:, :, :, : key_states.shape[-2]]
+            position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
@@ -860,10 +869,9 @@ class Pix2StructTextAttention(nn.Module):
            position_bias_masked = position_bias

        scores += position_bias_masked
        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
@@ -871,12 +879,12 @@ class Pix2StructTextAttention(nn.Module):
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)
-        # (batch_size, seq_length, dim)
-        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.output(attn_output)

-        outputs = (attn_output,) + (past_key_value,) + (position_bias,)
+        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
@@ -969,7 +977,10 @@ class Pix2StructTextBlock(nn.Module):
            layer_idx=layer_idx,
        )

-        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
+        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
+            config,
+            layer_idx=layer_idx,
+        )

        self.mlp = Pix2StructTextLayerFF(config)

@@ -1019,7 +1030,6 @@ class Pix2StructTextBlock(nn.Module):
-                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states, past_key_value = cross_attention_outputs[:2]
@ -52,6 +52,8 @@ class PixtralVisionConfig(PretrainedConfig):
Dropout probability for the attention layers.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

Example:

@ -82,6 +84,7 @@ class PixtralVisionConfig(PretrainedConfig):
hidden_act="gelu",
attention_dropout=0.0,
rope_theta=10000.0,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)

@ -97,3 +100,4 @@ class PixtralVisionConfig(PretrainedConfig):
self.hidden_act = hidden_act
self.rope_theta = rope_theta
self.head_dim = hidden_size // num_attention_heads
self.initializer_range = initializer_range
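A minimal usage sketch of the new `rope_theta` field and the derived `head_dim`, assuming only that `PixtralVisionConfig` is importable from `transformers`; the parameter values are arbitrary examples, not recommended settings.

from transformers import PixtralVisionConfig

# rope_theta is the base period of the rotary position embeddings; head_dim is
# derived as hidden_size // num_attention_heads, as in the constructor above.
config = PixtralVisionConfig(hidden_size=1024, num_attention_heads=16, rope_theta=10000.0)
assert config.head_dim == 1024 // 16
print(config.rope_theta, config.head_dim)  # 10000.0 64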
@ -407,7 +407,7 @@ class PixtralPreTrainedModel(PreTrainedModel):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
else self.config.initializer_range
)

if isinstance(module, (nn.Linear, nn.Conv2d)):
@ -206,14 +206,15 @@ class PixtralProcessor(ProcessorMixin):
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
elif (
not isinstance(images, list)
and not isinstance(images[0], list)
and not is_image_or_image_url(images[0][0])
):
if isinstance(text, list):
images = [[im] for im in images]
else:
images = [images]
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
pass
else:
raise ValueError(
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
"Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
)
images = [[load_image(im) for im in sample] for sample in images]
image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
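A standalone sketch of the nesting rules enforced above: every input ends up as a list of per-sample image lists. The helper name `nest_images` and the simple `isinstance` checks are illustrative stand-ins, not the processor's own code.

def nest_images(images, text_is_batched: bool):
    # Single image -> one sample containing one image.
    if not isinstance(images, list):
        return [[images]]
    # Flat list of images -> one image per sample when the text is batched,
    # otherwise treat the whole list as a single multi-image sample.
    if not isinstance(images[0], list):
        return [[im] for im in images] if text_is_batched else [images]
    # Already a list of lists of images -> leave untouched.
    return images

print(nest_images("img.png", text_is_batched=False))          # [['img.png']]
print(nest_images(["a.png", "b.png"], text_is_batched=True))   # [['a.png'], ['b.png']]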
@ -1059,7 +1059,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:

@ -1239,7 +1239,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:

@ -1321,7 +1321,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:

@ -1503,13 +1503,14 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
mrope_position_deltas = []
if image_grid_thw is not None or video_grid_thw is not None:
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
)
image_index, video_index = 0, 0
for i, input_ids in enumerate(total_input_ids):
if attention_mask is not None:
input_ids = input_ids[attention_mask[i] == 1]
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]

@ -1033,7 +1033,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
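A toy sketch, independent of the model code above, of how a causal mask is combined with a sliding window of the kind `sliding_attend_mask` encodes; the sizes and names are made up.

import torch

seq_len, sliding_window = 6, 3
positions = torch.arange(seq_len)
# Causal part: a query may attend to itself and to earlier positions.
causal = positions[None, :] <= positions[:, None]
# Window part: keys more than sliding_window - 1 steps behind the query are dropped.
in_window = positions[None, :] > (positions[:, None] - sliding_window)
# Each query row ends up attending to at most the last `sliding_window` positions.
print((causal & in_window).int())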
@ -622,8 +622,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
# TODO: @raushan retain only the new behavior after v4.47
else:
if pixel_values_images is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@ -638,8 +638,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

if pixel_values_videos is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
n_video_features = video_features.shape[1]
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0] * video_features.shape[1]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"

@ -714,17 +714,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
video_token_not_enough and pixel_values_videos is not None
)

model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,

@ -735,7 +724,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
**kwargs,
)

if legacy_processing or cache_position[0] == 0:
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values_images"] = pixel_values_images
@ -461,72 +461,71 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or (input_ids.shape[-1] == 1 and pixel_values is not None)

image_features = None
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
)

if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in VipLLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in VipLLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
# prefill stage vs decoding stage (legacy behavior copied)
if input_ids.shape[1] != 1:
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
# prefill stage vs decoding stage (legacy behavior copied)
if input_ids.shape[1] != 1:
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# in the case one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]

# TODO: @raushan retain only the new behavior after v4.47
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

# Filter out only the tokens that can be un-attended, this can happen
# in the case one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

# TODO: @raushan retain only the new behavior after v4.47
elif image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

outputs = self.language_model(
attention_mask=attention_mask,

@ -585,12 +584,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

# Trigger the new behavior if we have more than image embeddings seq length tokens for images
legacy_processing = (
input_ids is not None
and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
)

model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,

@ -601,7 +594,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
**kwargs,
)

if legacy_processing or cache_position[0] == 0:
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
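A toy sketch of the `masked_scatter` pattern used above to splice image features into the text embedding sequence at image-token positions; the token id 99 and the tiny shapes are made up for illustration.

import torch

inputs_embeds = torch.zeros(1, 5, 4)           # (batch, seq_len, hidden_dim)
image_features = torch.ones(1, 2, 4)           # two image "patch" embeddings
input_ids = torch.tensor([[7, 99, 99, 8, 9]])  # 99 plays the role of the image token id
# Expand the boolean token mask to the embedding shape, then scatter the image
# features into exactly those positions, in order.
special_image_mask = (input_ids == 99).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1:3])  # rows 1 and 2 now hold the image features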
@ -2,7 +2,7 @@
import datetime
import platform
import subprocess
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List

import numpy as np

@ -51,7 +51,7 @@ def ffmpeg_microphone(
chunk_length_s: float,
format_for_conversion: str = "f32le",
ffmpeg_input_device: Optional[str] = None,
ffmpeg_additional_args: Optional[list[str]] = None,
ffmpeg_additional_args: Optional[List[str]] = None,
):
"""
Helper function to read audio from a microphone using ffmpeg. The default input device will be used unless another

@ -138,7 +138,7 @@ def ffmpeg_microphone_live(
stride_length_s: Optional[Union[Tuple[float, float], float]] = None,
format_for_conversion: str = "f32le",
ffmpeg_input_device: Optional[str] = None,
ffmpeg_additional_args: Optional[list[str]] = None,
ffmpeg_additional_args: Optional[List[str]] = None,
):
"""
Helper function to read audio from a microphone using ffmpeg. This will output `partial` overlapping chunks starting
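A hedged usage sketch of `ffmpeg_microphone_live` with the `ffmpeg_additional_args` parameter whose type hint is touched above, assuming the helper lives in `transformers.pipelines.audio_utils` and that a working ffmpeg binary is on PATH; the argument values are illustrative only.

from transformers.pipelines.audio_utils import ffmpeg_microphone_live

# Stream roughly 5-second chunks of 16 kHz microphone audio; any extra flags are
# forwarded to the ffmpeg process through ffmpeg_additional_args.
mic = ffmpeg_microphone_live(
    sampling_rate=16_000,
    chunk_length_s=5.0,
    ffmpeg_additional_args=["-nostdin"],
)
for chunk in mic:
    # Each yielded item carries one captured audio chunk for downstream ASR.
    print(type(chunk))
    break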
@ -314,7 +314,7 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)

Args:
elements (`torch.Tensor`): Input elements
test_elements (`torch.Tensor`): The elements to check against.
test_elements (`torch.Tensor` or `int`): The elements to check against.

Returns:
`torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`

@ -322,6 +322,9 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
"""

if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
test_elements = torch.tensor(test_elements)
if test_elements.ndim == 0:
test_elements = test_elements.unsqueeze(0)
return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
else:
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
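A minimal usage sketch for the helper patched above, assuming it is importable from `transformers.pytorch_utils`; the values are arbitrary.

import torch
from transformers.pytorch_utils import isin_mps_friendly

elements = torch.tensor([1, 5, 9, 5])
# A plain int, a 0-D tensor, and a 1-D tensor are all accepted as test_elements;
# each call returns a boolean mask shaped like `elements`, matching torch.isin.
print(isin_mps_friendly(elements, 5))
print(isin_mps_friendly(elements, torch.tensor(5)))
print(isin_mps_friendly(elements, torch.tensor([5, 9])))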
@ -233,7 +233,6 @@ if is_accelerate_available():
from accelerate.utils import (
DistributedDataParallelKwargs,
DistributedType,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,

@ -589,8 +588,10 @@ class Trainer:
if not _is_peft_model(unwrapped_model)
else unwrapped_model.get_base_model().forward
)

self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters
forward_params = inspect.signature(model_forward).parameters
self.model_accepts_loss_kwargs = (
"loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD
)

self.neftune_noise_alpha = args.neftune_noise_alpha

@ -2424,7 +2425,7 @@ class Trainer:
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for inputs in batch_samples:
for i, inputs in enumerate(batch_samples):
step += 1
total_batched_samples += 1
is_last_step_and_steps_less_than_grad_acc = (

@ -2470,7 +2471,13 @@ class Trainer:
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

with self.accelerator.accumulate(model):
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

if (

@ -3602,10 +3609,11 @@ class Trainer:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss *= self.args.gradient_accumulation_steps
self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
return loss.detach() / self.args.gradient_accumulation_steps
return loss.detach()

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""

@ -3650,6 +3658,9 @@ class Trainer:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes

return (loss, outputs) if return_outputs else loss

def is_local_process_zero(self) -> bool:

@ -4894,24 +4905,21 @@ class Trainer:
self.repo.git_push()

def create_accelerator_and_postprocess(self):
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

grad_acc_kwargs["sync_with_dataloader"] = False

gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
if "num_steps" in grad_acc_kwargs:
if self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
else:
self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

accelerator_config = self.args.accelerator_config.to_dict()

@ -4942,7 +4950,6 @@ class Trainer:

args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
}
if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config

@ -5038,12 +5045,18 @@ class Trainer:
batch_samples += [next(epoch_iterator)]
except StopIteration:
break

# Keep default behavior the same
if not self.model_accepts_loss_kwargs:
return batch_samples, None

if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
num_items_in_batch = sum(
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
)
except TypeError:
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
except (TypeError, AttributeError):
pass

if self.args.average_tokens_across_devices:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
return batch_samples, num_items_in_batch
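A small self-contained sketch of the label-token counting added in `get_batch_samples` above; the toy batches and the `-100` ignore index stand in for real collated causal-LM training batches.

import torch

batch_samples = [
    {"labels": torch.tensor([[-100, 12, 7, -100]])},
    {"labels": torch.tensor([[5, 6, -100, -100]])},
]
# Count label positions that actually contribute to the loss (everything except
# the ignore index), shifted by one to mimic causal-LM label shifting.
num_items_in_batch = sum(
    batch["labels"][..., 1:].ne(-100).sum().item() for batch in batch_samples
)
print(num_items_in_batch)  # 3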
@ -1530,6 +1530,15 @@ class TrainingArguments:
},
)

average_tokens_across_devices: Optional[bool] = field(
default=False,
metadata={
"help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
"synchronize num_tokens_in_batch for precise loss calculation. Reference: "
"https://github.com/huggingface/transformers/issues/34242"
},
)

def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS:

@ -1763,6 +1772,19 @@ class TrainingArguments:
if self.framework == "pt" and is_torch_available():
self.device

# Disable average tokens when using single device
if self.average_tokens_across_devices:
try:
if self.world_size == 1:
logger.warning(
"average_tokens_across_devices is set to True but it is invalid when world size is"
"1. Turn it to False automatically."
)
self.average_tokens_across_devices = False
except ImportError as e:
logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.")
self.average_tokens_across_devices = False

if self.torchdynamo is not None:
warnings.warn(
"`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
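A conceptual, single-process sketch of what `average_tokens_across_devices` aims at: normalizing the loss by the global label-token count gathered from all ranks rather than the local count; the numbers below are made up and no distributed setup is involved.

import torch

per_rank_token_counts = torch.tensor([120, 80])  # as if gathered from two ranks
global_tokens = per_rank_token_counts.sum()      # 200 tokens across devices
local_loss_sum = torch.tensor(95.0)              # sum of token losses on this rank
# Dividing by the global count keeps the loss scale consistent across ranks.
print(float(local_loss_sum / global_tokens))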
@ -1416,7 +1416,7 @@ class HFTracer(Tracer):
your custom tracer.
"""
attribute = HFAttribute(obj, "keys")()
if obj.node.target == "**kwargs":
if obj.node.target.startswith("**"):
return attribute._metadata
return attribute

@ -304,7 +304,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
@ -235,6 +235,35 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))

def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs throw an error with an explicit message saying what is wrong
when the number of images doesn't match the number of image tokens in the text.
Also we need to test multi-image cases when one prompt has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications

# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)

# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)

# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values=pixel_values)

# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values)

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
@ -283,6 +283,38 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))

def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs throw an error with an explicit message saying what is wrong
when the number of images doesn't match the number of image tokens in the text.
Also we need to test multi-image cases when one prompt has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications

# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)

# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
image_sizes = input_dict["image_sizes"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)

# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
@ -303,6 +303,38 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))

def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs throw an error with an explicit message saying what is wrong
when the number of images doesn't match the number of image tokens in the text.
Also we need to test multi-image cases when one prompt has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications

# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)

# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
image_sizes = input_dict["image_sizes"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)

# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
@ -356,7 +356,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()

@ -356,7 +356,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
@ -236,6 +236,36 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))

# Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs throw an error with an explicit message saying what is wrong
when the number of images doesn't match the number of image tokens in the text.
Also we need to test multi-image cases when one prompt has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications

# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)

# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)

# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values=pixel_values)

# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values)

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
@ -419,6 +419,7 @@ class Pix2StructModelTester:
@require_torch
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False
test_head_masking = False

@ -445,6 +446,16 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
),
)

def test_generative_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
model = model_class(config).eval().to(torch_device)

output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10)
output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10)

torch.testing.assert_close(output, output_use_cache)

@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@ -14,22 +14,16 @@
# limitations under the License.
"""Testing suite for the PyTorch Pixtral model."""

import gc
import unittest

import requests

from transformers import (
AutoProcessor,
PixtralVisionConfig,
PixtralVisionModel,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
slow,
torch_device,
)

@ -43,7 +37,7 @@ else:
is_torch_greater_or_equal_than_2_0 = False

if is_vision_available():
from PIL import Image
pass

class PixtralVisionModelTester:

@ -148,6 +142,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
test_torchscript = False

def setUp(self):
self.model_tester = PixtralVisionModelTester(self)

@ -258,35 +253,3 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
@unittest.skip(reason="Not supported yet")
def test_determinism(self):
pass

@require_torch
class PixtralVisionModelIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

@slow
@require_bitsandbytes
def test_small_model_integration_test(self):
# Let's make sure we test the preprocessing to replace what is used
model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)

prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(prompt, raw_image, return_tensors="pt")

EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip

self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@ -171,7 +171,7 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
)
# fmt: on

# Test passing in a url

@ -246,6 +246,25 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
)
# fmt: on

def test_processor_returns_full_length_batches(self):
# to avoid https://github.com/huggingface/transformers/issues/34204
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = [
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
] * 5
processor.tokenizer.pad_token = "</s>"
image_inputs = [self.image_0] * 5

# Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30}
processor.image_processor.patch_size = {"height": 2, "width": 2}

# Test passing in an image
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 5)
self.assertTrue(len(inputs_image["pixel_values"]) == 5)

# Override as PixtralProcessor needs nested images to work properly with batched inputs
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
@ -368,7 +368,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()

@ -391,7 +391,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
@ -58,7 +58,7 @@ class Qwen2VLVisionText2TextModelTester:
def __init__(
self,
parent,
batch_size=2,
batch_size=3,
seq_length=7,
num_channels=3,
ignore_index=-100,

@ -245,6 +245,40 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)

def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs throw an error with an explicit message saying what is wrong
when the number of images doesn't match the number of image tokens in the text.
Also we need to test multi-image cases when one prompt has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications

# remove one image but leave the image token in text
patch_size = config.vision_config.patch_size
one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)

# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:one_img_length]
image_grid_thw = input_dict["image_grid_thw"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)

# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)

# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
@ -116,9 +116,9 @@ class VideoLlavaVisionText2TextModelTester:
self.batch_size = 5
self.num_channels = 3
self.image_size = 224
self.encoder_seq_length = 64
self.encoder_seq_length = 246
self.num_image_tokens = 25
self.num_video_tokens = 26
self.num_video_tokens = 26 * self.num_frames
self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens

def get_config(self):

@ -262,7 +262,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
# if we remove some images from inputs leaving only one
# image number mismatch error should raise
inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
with self.assertRaises(RuntimeError):
with self.assertRaises(ValueError):
_ = model(**inputs)

def test_video_only_input(self):

@ -396,6 +396,35 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))

def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs throw an error with an explicit message saying what is wrong
when the number of images doesn't match the number of image tokens in the text.
Also we need to test multi-image cases when one prompt has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications

# remove one image but leave the image token in text
input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)

# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values_images"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)

# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values_images=pixel_values)

# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
_ = model(input_ids=input_ids, pixel_values_images=pixel_values)

@require_torch
class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
@ -217,6 +217,36 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))

# Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs throw an error with an explicit message saying what is wrong
when the number of images doesn't match the number of image tokens in the text.
Also we need to test multi-image cases when one prompt has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications

# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)

# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)

# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values=pixel_values)

# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values)

@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
@ -208,6 +208,26 @@ class TorchAoTest(unittest.TestCase):

self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)

def test_int8_dynamic_activation_int8_weight_quant(self):
"""
Simple LLM model testing int8_dynamic_activation_int8_weight
"""
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")

# Note: we quantize the bfloat16 model on the fly to int8
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=torch_device,
quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)

input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)

output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)

if __name__ == "__main__":
unittest.main()
@ -272,6 +272,19 @@ class RepeatDataset:
return {"input_ids": self.x, "labels": self.x}

class SequenceClassificationDataset:
def __init__(self, length=64, vocab_size=100, num_labels=5):
self.length = length
self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)]
self.labels = torch.randint(0, num_labels, (length,)).tolist()

def __len__(self):
return self.length

def __getitem__(self, i):
return {"input_ids": self.sequences[i], "label": self.labels[i]}

class DynamicShapesDataset:
def __init__(self, length=64, seed=42, batch_size=8):
self.length = length

@ -1144,6 +1157,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)

def test_torch_compile_loss_func_compatibility(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)

x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
tmp_dir,
per_device_train_batch_size=2,
torch_compile=True,
max_steps=1, # compile happens on the first step
)
trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset) # noqa
trainer.train()

@require_peft
@require_bitsandbytes
def test_bnb_compile(self):

@ -3676,9 +3706,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively

@ -3691,8 +3718,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
"even_batches": False,
"use_seedable_sampler": False,
}
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)

@ -3706,9 +3731,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)

if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively

@ -3754,10 +3776,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
self.assertEqual(trainer.args.gradient_accumulation_steps, 10)

def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through
@ -1710,7 +1710,12 @@ class ModelUtilsTest(TestCasePlus):
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
)
)
# We can match against an tensor of integers
# We can match against an 0D tensor
random_test_tensor = torch.randint(0, 100, (1,)).squeeze()
self.assertTrue(
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
# We can match against an 1D tensor (with many items)
random_test_tensor = torch.randint(0, 100, (10,))
self.assertTrue(
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))