Mirror of https://github.com/huggingface/transformers.git (synced 2025-11-04 20:14:36 +08:00)

Compare commits: continuous...v4.46.2 (14 commits)
| SHA1 |
|---|
| ccbd57a8b6 |
| e66224b544 |
| 8c62a92b3c |
| 5b36cdabf5 |
| f784d95c0f |
| 7da0eefc27 |
| bc598c00db |
| 94ed13c1de |
| 72c716de92 |
| 97bb9299c4 |
| 565f0e97c2 |
| dcfe3c7e61 |
| c2820c9491 |
| b298161146 |
@@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 Array = Any
 Dataset = datasets.arrow_dataset.Dataset

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
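Every example-script hunk in this compare is the same one-line release pin: the development pin `check_min_version("4.46.0.dev0")` becomes the release pin `check_min_version("4.46.0")`, while the surrounding `require_version` calls are untouched context. As a rough sketch of what those two guards do (both utilities are imported exactly as the hunk headers above show; the datasets pin below is just one of the values that appears in this diff):

```python
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

# Both raise at import time if the installed packages are older than the pins,
# which is why the scripts keep the "Remove at your own risk(s)" comment above them.
check_min_version("4.46.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
```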
@@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry

 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 Array = Any
 Dataset = datasets.arrow_dataset.Dataset

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)
@@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = logging.getLogger(__name__)

@@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logging.basicConfig(level=logging.INFO)
 logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 logger = logging.getLogger(__name__)

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version

 # region Checking dependencies
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 task_to_keys = {
     "cola": ("sentence", None),

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

 # region Dependencies and constants
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.46.0.dev0")
+check_min_version("4.46.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
setup.py

@@ -435,7 +435,7 @@ install_requires = [

 setup(
     name="transformers",
-    version="4.46.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.46.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="transformers@huggingface.co",
     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
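The comment on the version line spells out the accepted formats (x.y.z.dev0, x.y.z.rc1, or plain x.y.z, dots only). A throwaway check of the values touched by this bump; the regex is my own illustration of that comment, not anything from the repo's release tooling:

```python
import re

# Illustrative pattern only: encodes the formats listed in the setup.py comment.
VERSION_RE = re.compile(r"^\d+\.\d+\.\d+(\.(dev|rc)\d+)?$")

for candidate in ("4.46.0.dev0", "4.46.2", "4.46.0.rc1", "4.46-2"):
    print(candidate, bool(VERSION_RE.fullmatch(candidate)))
# 4.46.0.dev0 True, 4.46.2 True, 4.46.0.rc1 True, 4.46-2 False ("no to dashes, yes to dots")
```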
@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).

-__version__ = "4.46.0.dev0"
+__version__ = "4.46.2"

 from typing import TYPE_CHECKING
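Once this tag is installed, the runtime version string reflects the same bump:

```python
import transformers

# Prints "4.46.2" for an installation built from this tag.
print(transformers.__version__)
```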
@@ -28,7 +28,7 @@ import tempfile
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
-from functools import lru_cache, partial, wraps
+from functools import partial, wraps
 from threading import Thread
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile
@@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
         old_param = model
         splits = param_name.split(".")
         for split in splits:
-            old_param = getattr(old_param, split)
-            # Not all the attributes of a module are Parameters/Tensor
-            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
-                old_param = None
+            # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
+            old_param = getattr(old_param, split, None)
+            if old_param is None:
+                break

+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+
         if old_param is not None:
             if dtype is None:
                 param = param.to(old_param.dtype)
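The change above swaps a bare `getattr` for `getattr(..., None)` plus an early `break`, and moves the Parameter/Tensor check outside the loop, so dotted parameter names that do not fully resolve on the module (the new comment mentions quantization methods like hqq rewriting `expected_keys`) no longer raise mid-traversal. A standalone sketch of just that traversal; `resolve_param` is my own name, not anything in the library:

```python
import torch

def resolve_param(model: torch.nn.Module, param_name: str):
    """Walk a dotted name like 'encoder.layer.0.weight', tolerating missing attributes."""
    obj = model
    for split in param_name.split("."):
        # Default of None mirrors the new getattr(old_param, split, None) in the hunk above.
        obj = getattr(obj, split, None)
        if obj is None:
            break
    # Non-tensor attributes (or a failed lookup) are treated as "no existing parameter".
    if not isinstance(obj, (torch.nn.Parameter, torch.Tensor)):
        return None
    return obj

model = torch.nn.Sequential(torch.nn.Linear(2, 2))
print(resolve_param(model, "0.weight").shape)    # torch.Size([2, 2])
print(resolve_param(model, "0.missing.weight"))  # None instead of an AttributeError
```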
@@ -5013,7 +5014,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return self.hf_quantizer.is_trainable

     @property
-    @lru_cache
     def loss_function(self):
         if getattr(self.config, "loss_type", None) is not None:
             loss_type = self.config.loss_type
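The only change here is dropping `@lru_cache` from under `@property` (its `functools` import disappears in the earlier hunk). The diff does not state the motivation, so the sketch below only demonstrates the generic behaviour of that stacking: `lru_cache` keys on `self`, so the first computed value sticks even if the underlying config later changes, and the cache holds a strong reference to the instance. The class bodies are stand-ins, not the library's code:

```python
from functools import lru_cache

class Config:
    loss_type = None

class Model:
    def __init__(self):
        self.config = Config()

    @property
    @lru_cache
    def loss_function(self):
        # Stand-in body; the real method resolves a loss from self.config.loss_type.
        return getattr(self.config, "loss_type", None) or "default_loss"

m = Model()
print(m.loss_function)              # "default_loss"
m.config.loss_type = "ForCausalLM"
print(m.loss_function)              # still "default_loss": cached from the first access
```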
@@ -1288,7 +1288,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
-            n_image_features = image_tokens.shape[0]
+            n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
             if n_image_tokens_in_text != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
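The Chameleon fix compares the count of image placeholder tokens in the text against the total number of image tokens produced for the whole batch, so both dimensions of `image_tokens` have to be multiplied rather than using the batch dimension alone. With assumed toy shapes:

```python
import torch

# Assumed shapes for illustration: 2 images, 1024 discrete image tokens per image.
image_tokens = torch.zeros(2, 1024, dtype=torch.long)

n_old = image_tokens.shape[0]                          # 2: compares a batch size against a token count
n_new = image_tokens.shape[0] * image_tokens.shape[1]  # 2048: total image tokens expected in input_ids
print(n_old, n_new)
```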
@@ -467,6 +467,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

+        image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
@@ -474,69 +475,68 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            if legacy_processing:
                logger.warning_once(
                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
        if legacy_processing:
            logger.warning_once(
                "Expanding inputs for image tokens in LLaVa should be done in processing. "
                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            # prefill stage vs decoding stage (legacy behavior copied)
            if input_ids.shape[1] != 1:
                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, labels
                )
                # prefill stage vs decoding stage (legacy behavior copied)
                if input_ids.shape[1] != 1:
                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                        image_features, inputs_embeds, input_ids, attention_mask, labels
                    )
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    # Get the target length
                    target_length = input_ids.shape[1]
                    past_length = first_layer_past_key_value.shape[-1]

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], past_length),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # if one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
                        -target_length:
                    ]

            # TODO: @raushan retain only the new behavior after v4.47
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
            else:
                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
                n_image_features = image_features.shape[1]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.image_token_index)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                target_length = input_ids.shape[1]
                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

        # TODO: @raushan retain only the new behavior after v4.47
        elif image_features is not None:
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0] * image_features.shape[1]

            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
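The restructured block keeps the deprecation warning verbatim: the legacy path only triggers when the processor did not already expand image tokens, and the warning's remedy is to set `patch_size` and `vision_feature_select_strategy` on the processor. A minimal sketch of that remedy; the checkpoint name and the value 14 are assumptions for illustration, not something this diff specifies:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # assumed checkpoint
# Setting these lets the processor expand image tokens itself, so the model's
# legacy_processing branch (and its v4.47 deprecation warning) is not taken.
processor.patch_size = 14                             # assumed: patch size of the vision tower
processor.vision_feature_select_strategy = "default"  # assumed: must match the model config
```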
@@ -597,12 +597,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -613,7 +607,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
             **kwargs,
         )

-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
             model_inputs["pixel_values"] = pixel_values
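With `legacy_processing` gone from `prepare_inputs_for_generation`, the only gate for forwarding `pixel_values` is `cache_position[0] == 0`, i.e. the prefill call; per the comment in the hunk, the cached decoding steps carry no image placeholder tokens, so pixel values are not forwarded. A toy check of that condition (tensor values assumed):

```python
import torch

prefill_cache_position = torch.arange(0, 7)  # first forward pass: positions 0..6
decode_cache_position = torch.tensor([7])    # later steps: a single new position

print(bool(prefill_cache_position[0] == 0))  # True  -> model_inputs["pixel_values"] = pixel_values
print(bool(decode_cache_position[0] == 0))   # False -> pixel values stay out of model_inputs
```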
@@ -846,6 +846,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

+        image_features = None
         if pixel_values is not None and pixel_values.size(0) > 0:
             image_features = self.get_image_features(
                 pixel_values,
@@ -861,74 +862,73 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )
            if legacy_processing:
                logger.warning_once(
                    "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."

        if legacy_processing:
            logger.warning_once(
                "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            if input_ids.shape[1] != 1:
                inputs_embeds = inputs_embeds.to(image_features.dtype)
                inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
                    image_features,
                    feature_lens,
                    inputs_embeds,
                    input_ids,
                    attention_mask,
                    position_ids,
                    labels=labels,
                )
                if input_ids.shape[1] != 1:
                    inputs_embeds = inputs_embeds.to(image_features.dtype)
                    inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
                        image_features,
                        feature_lens,
                        inputs_embeds,
                        input_ids,
                        attention_mask,
                        position_ids,
                        labels=labels,
                    )
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    # Get the target length
                    target_length = input_ids.shape[1]
                    past_length = first_layer_past_key_value.shape[-1]

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], past_length),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # if one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
                        -target_length:
                    ]

            # TODO: @raushan retain only the new behavior after v4.47
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
            else:
                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
                n_image_features = image_features.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.image_token_index)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                target_length = input_ids.shape[1]
                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

        # TODO: @raushan retain only the new behavior after v4.47
        elif image_features is not None:
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
@@ -990,11 +990,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -1007,7 +1002,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi

         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["image_sizes"] = image_sizes
@@ -1020,6 +1020,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
            if image_features is not None:
                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
                n_image_features = image_features.shape[0]

                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -1110,17 +1111,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
     ):
         # Overwritten -- extra custom processing

-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -1133,7 +1123,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene

         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["pixel_values_videos"] = pixel_values_videos
             model_inputs["image_sizes"] = image_sizes
@@ -533,6 +533,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
            if image_features is not None:
                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
                n_image_features = image_features.shape[0]

                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@ -623,17 +624,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
 | 
			
		||||
    ):
 | 
			
		||||
        # Overwritten -- extra custom processing
 | 
			
		||||
 | 
			
		||||
        if input_ids is not None:
 | 
			
		||||
            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
 | 
			
		||||
                1
 | 
			
		||||
            ).max() < self.config.image_seq_length
 | 
			
		||||
            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
 | 
			
		||||
                1
 | 
			
		||||
            ).max() < self.config.video_seq_length
 | 
			
		||||
            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
 | 
			
		||||
                video_token_not_enough and pixel_values_videos is not None
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        model_inputs = self.language_model.prepare_inputs_for_generation(
 | 
			
		||||
            input_ids,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
@ -646,7 +636,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
 | 
			
		||||
 | 
			
		||||
        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
 | 
			
		||||
        # Otherwise we need pixel values to be passed to model
 | 
			
		||||
        if legacy_processing or cache_position[0] == 0:
 | 
			
		||||
        if cache_position[0] == 0:
 | 
			
		||||
            model_inputs["pixel_values"] = pixel_values
 | 
			
		||||
            model_inputs["pixel_values_videos"] = pixel_values_videos
 | 
			
		||||
            model_inputs["image_sizes"] = image_sizes
 | 
			
		||||
 | 
			
		||||
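All of the prepare_inputs_for_generation hunks above converge on the same rule: vision inputs are only forwarded on the prefill step, because once the cache is populated the image placeholders are already expanded. A minimal sketch of that gating (the wrapper class, its language_model attribute, and the kwargs are illustrative, not the exact transformers signature):

# Sketch only: names follow the diff above, but this is not the library implementation.
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, pixel_values=None,
                                  image_sizes=None, cache_position=None, **kwargs):
    model_inputs = self.language_model.prepare_inputs_for_generation(
        input_ids, past_key_values=past_key_values, cache_position=cache_position, **kwargs
    )
    # Forward pixel values only while prefilling (cache_position starts at 0);
    # during cached decoding the image features already live in the KV cache.
    if cache_position[0] == 0:
        model_inputs["pixel_values"] = pixel_values
        model_inputs["image_sizes"] = image_sizes
    return model_inputs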
@@ -679,6 +679,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
            )
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0]

            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -704,6 +705,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
            )
            video_features = torch.cat((video_features, image_newline), dim=1)
            video_features = video_features.flatten(0, 1)

            n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
            n_video_features = video_features.shape[0]
            if n_video_tokens != n_video_features:

@@ -1156,7 +1156,7 @@ class MimiTransformerModel(nn.Module):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

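The one-line change repeated in this and the following decoder hunks swaps the `|=` operator for an explicit in-place `bitwise_or_`, which behaves identically on boolean tensors. A standalone sketch of the sliding-window masking step, with toy values standing in for what the models read from the config and the KV cache:

import torch

# Toy shapes for illustration; the models above derive these from config/cache.
target_length, sliding_window = 8, 4
cache_position = torch.arange(target_length)

# True where the key position lies in the future of the query (must be masked).
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
# True where the key position is further back than the sliding window allows.
sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
# In-place boolean OR; equivalent to `diagonal_attend_mask |= sliding_attend_mask`.
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)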
@@ -961,7 +961,7 @@ class MistralModel(MistralPreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

@@ -1174,7 +1174,7 @@ class MixtralModel(MixtralPreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

@@ -1385,7 +1385,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
@@ -1689,7 +1689,7 @@ class MoshiModel(MoshiPreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

@@ -1139,7 +1139,7 @@ class Phi3Model(Phi3PreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

@@ -1305,7 +1305,7 @@ class PhimoeModel(PhimoePreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

@@ -762,11 +762,14 @@ class Pix2StructTextAttention(nn.Module):
        return relative_buckets

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
    def compute_bias(self, query_length, key_length, device=None):
    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
@@ -779,6 +782,7 @@ class Pix2StructTextAttention(nn.Module):
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
    def forward(
        self,
        hidden_states,
@@ -796,61 +800,66 @@ class Pix2StructTextAttention(nn.Module):
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length)
        # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.query(hidden_states).contiguous()
        query_states = self.query(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                past_key_value = past_key_value.cross_attention_cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                past_key_value = past_key_value.self_attention_cache
                curr_past_key_value = past_key_value.self_attention_cache

        # get key/value states
        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value and is_updated:
            # reuse k,v, cross_attentions
            key_states = past_key_value.key_cache[self.layer_idx]
            value_states = past_key_value.value_cache[self.layer_idx]
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.key(current_states).contiguous()
            value_states = self.value(current_states).contiguous()
            key_states = self.key(current_states)
            value_states = self.value(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = past_key_value.update(
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores
        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            real_seq_length = cache_position[-1] + 1 if query_length is None else query_length
            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
@@ -860,10 +869,9 @@ class Pix2StructTextAttention(nn.Module):
            position_bias_masked = position_bias

        scores += position_bias_masked
        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
@@ -871,12 +879,12 @@ class Pix2StructTextAttention(nn.Module):
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)
        # (batch_size, seq_length, dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.output(attn_output)

        outputs = (attn_output,) + (past_key_value,) + (position_bias,)
        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
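The compute_bias change above derives the query positions from cache_position during cached decoding, so the relative position bias is built only for the rows actually being scored. A self-contained sketch of just that position bookkeeping (the bucket lookup itself is omitted, and the helper name is made up for illustration):

import torch

def relative_positions(query_length, key_length, cache_position=None):
    # Sketch of the bookkeeping inside compute_bias above (bias table lookup omitted).
    if cache_position is None:
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
    else:
        # During cached decoding only the in-flight query positions are needed.
        context_position = cache_position[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
    return memory_position - context_position  # shape (query_length, key_length)

# Prefill: queries 0..3 against keys 0..3.
print(relative_positions(4, 4))
# Decoding step: a single query at absolute position 3 against all 4 keys.
print(relative_positions(1, 4, cache_position=torch.tensor([3])))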
@@ -969,7 +977,10 @@ class Pix2StructTextBlock(nn.Module):
            layer_idx=layer_idx,
        )

        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
            config,
            layer_idx=layer_idx,
        )

        self.mlp = Pix2StructTextLayerFF(config)

@@ -1019,7 +1030,6 @@ class Pix2StructTextBlock(nn.Module):
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states, past_key_value = cross_attention_outputs[:2]


@@ -52,6 +52,8 @@ class PixtralVisionConfig(PretrainedConfig):
            Dropout probability for the attention layers.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Example:

@@ -82,6 +84,7 @@ class PixtralVisionConfig(PretrainedConfig):
        hidden_act="gelu",
        attention_dropout=0.0,
        rope_theta=10000.0,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
@@ -97,3 +100,4 @@ class PixtralVisionConfig(PretrainedConfig):
        self.hidden_act = hidden_act
        self.rope_theta = rope_theta
        self.head_dim = hidden_size // num_attention_heads
        self.initializer_range = initializer_range

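For reference, the fields touched above are ordinary config arguments. Assuming the public PixtralVisionConfig export, a hypothetical instantiation looks like this (the values are placeholders, not recommendations):

from transformers import PixtralVisionConfig

# Placeholder values for illustration; they happen to match the documented defaults.
config = PixtralVisionConfig(rope_theta=10000.0, initializer_range=0.02)
print(config.rope_theta, config.initializer_range, config.head_dim)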
@@ -407,7 +407,7 @@ class PixtralPreTrainedModel(PreTrainedModel):
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
            else self.config.initializer_range
        )

        if isinstance(module, (nn.Linear, nn.Conv2d)):

@@ -206,14 +206,15 @@ class PixtralProcessor(ProcessorMixin):
            if is_image_or_image_url(images):
                images = [[images]]
            elif isinstance(images, list) and is_image_or_image_url(images[0]):
                images = [images]
            elif (
                not isinstance(images, list)
                and not isinstance(images[0], list)
                and not is_image_or_image_url(images[0][0])
            ):
                if isinstance(text, list):
                    images = [[im] for im in images]
                else:
                    images = [images]
            elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
                pass
            else:
                raise ValueError(
                    "Invalid input images. Please provide a single image or a list of images or a list of list of images."
                    "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
                )
            images = [[load_image(im) for im in sample] for sample in images]
            image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])

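The processor hunk above reworks how images are nested before being handed to the image processor: every accepted shape of input ends up as one list of images per prompt. A rough standalone sketch of that idea, using a stand-in is_image check instead of the library's is_image_or_image_url helper (so the branching here is illustrative, not the exact processor logic):

from PIL import Image

def is_image(x):
    # Stand-in for is_image_or_image_url; the real helper also accepts URLs.
    return isinstance(x, Image.Image)

def nest_images(images, text):
    # Normalize to a list (batch) of lists (images per prompt).
    if is_image(images):
        return [[images]]
    if isinstance(images, list) and images and is_image(images[0]):
        # Flat list: one image per prompt when text is batched, else all images for one prompt.
        return [[im] for im in images] if isinstance(text, list) else [images]
    if isinstance(images, list) and images and isinstance(images[0], list) and is_image(images[0][0]):
        return images
    raise ValueError("Provide a single image, a list of images, or a list of lists of images.")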
@@ -1059,7 +1059,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

@@ -1239,7 +1239,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

@@ -1321,7 +1321,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
@@ -1503,13 +1503,14 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
        mrope_position_deltas = []
        if image_grid_thw is not None or video_grid_thw is not None:
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
            )
            image_index, video_index = 0, 0
            for i, input_ids in enumerate(total_input_ids):
                if attention_mask is not None:
                    input_ids = input_ids[attention_mask[i] == 1]
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]

@@ -1033,7 +1033,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask |= sliding_attend_mask
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:

@@ -622,8 +622,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
        # TODO: @raushan retain only the new behavior after v4.47
        else:
            if pixel_values_images is not None:
                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
                n_image_features = image_features.shape[1]
                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
                n_image_features = image_features.shape[0] * image_features.shape[1]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -638,8 +638,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

            if pixel_values_videos is not None:
                n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
                n_video_features = video_features.shape[1]
                n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
                n_video_features = video_features.shape[0] * video_features.shape[1]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
@@ -714,17 +714,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        if input_ids is not None:
            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
                1
            ).max() < self.config.image_seq_length
            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
                1
            ).max() < self.config.video_seq_length
            legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
                video_token_not_enough and pixel_values_videos is not None
            )

        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
@@ -735,7 +724,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
            **kwargs,
        )

        if legacy_processing or cache_position[0] == 0:
        if cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values_images"] = pixel_values_images

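The VideoLLaVA hunks above (like the earlier LLaVA-NeXT ones) move from a per-sample count to a global count, so multi-image prompts are covered: the number of image placeholder tokens across the whole batch has to equal the number of feature tokens coming out of the vision tower. A toy sketch of that check, with made-up ids and shapes:

import torch

# Toy numbers for illustration; in the models the id comes from the config and the
# features from the vision tower.
image_token_index = 32000
input_ids = torch.tensor([[1, 32000, 32000, 7],
                          [1, 32000, 32000, 7]])   # 4 placeholders in total
image_features = torch.randn(2, 2, 16)             # 2 images x 2 feature tokens each

n_image_tokens = (input_ids == image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    )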
@@ -461,72 +461,71 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

        image_features = None
        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
            )

            if legacy_processing:
                logger.warning_once(
                    "Expanding inputs for image tokens in VipLLaVa should be done in processing. "
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
        if legacy_processing:
            logger.warning_once(
                "Expanding inputs for image tokens in VipLLaVa should be done in processing. "
                "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            # prefill stage vs decoding stage (legacy behavior copied)
            if input_ids.shape[1] != 1:
                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, labels
                )
                # prefill stage vs decoding stage (legacy behavior copied)
                if input_ids.shape[1] != 1:
                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                        image_features, inputs_embeds, input_ids, attention_mask, labels
                    )
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    target_length = input_ids.shape[1]
                    past_length = first_layer_past_key_value.shape[-1]

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], past_length),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # in the case one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
                        -target_length:
                    ]

            # TODO: @raushan retain only the new behavior after v4.47
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
            else:
                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
                n_image_features = image_features.shape[1]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.image_token_index)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                target_length = input_ids.shape[1]
                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

                # Filter out only the tokens that can be un-attended, this can happen
                # in the case one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

        # TODO: @raushan retain only the new behavior after v4.47
        elif image_features is not None:
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0] * image_features.shape[1]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
@@ -585,12 +584,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
        legacy_processing = (
            input_ids is not None
            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
        )

        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
@@ -601,7 +594,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
            **kwargs,
        )

        if legacy_processing or cache_position[0] == 0:
        if cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = pixel_values

@@ -2,7 +2,7 @@
import datetime
import platform
import subprocess
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List

import numpy as np

@@ -51,7 +51,7 @@ def ffmpeg_microphone(
    chunk_length_s: float,
    format_for_conversion: str = "f32le",
    ffmpeg_input_device: Optional[str] = None,
    ffmpeg_additional_args: Optional[list[str]] = None,
    ffmpeg_additional_args: Optional[List[str]] = None,
):
    """
    Helper function to read audio from a microphone using ffmpeg. The default input device will be used unless another
@@ -138,7 +138,7 @@ def ffmpeg_microphone_live(
    stride_length_s: Optional[Union[Tuple[float, float], float]] = None,
    format_for_conversion: str = "f32le",
    ffmpeg_input_device: Optional[str] = None,
    ffmpeg_additional_args: Optional[list[str]] = None,
    ffmpeg_additional_args: Optional[List[str]] = None,
):
    """
    Helper function to read audio from a microphone using ffmpeg. This will output `partial` overlapping chunks starting

@@ -314,7 +314,7 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)

    Args:
        elements (`torch.Tensor`): Input elements
        test_elements (`torch.Tensor`): The elements to check against.
        test_elements (`torch.Tensor` or `int`): The elements to check against.

    Returns:
        `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
@@ -322,6 +322,9 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
    """

    if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
        test_elements = torch.tensor(test_elements)
        if test_elements.ndim == 0:
            test_elements = test_elements.unsqueeze(0)
        return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
    else:
        # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045

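The docstring fix above reflects that test_elements may also be a plain int. On non-MPS devices (or recent PyTorch) the helper reduces to torch.isin; a small sketch of the equivalence between the fallback path shown in the hunk and the reference call, run here on CPU tensors:

import torch

elements = torch.tensor([1, 2, 3, 4])
test_elements = torch.tensor([2, 4])

# Reference behaviour (positional arguments, per the note in the diff).
expected = torch.isin(elements, test_elements)

# The MPS fallback expression from the hunk above, written out for CPU tensors.
fallback = (
    elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
)
assert torch.equal(expected, fallback)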
@@ -233,7 +233,6 @@ if is_accelerate_available():
    from accelerate.utils import (
        DistributedDataParallelKwargs,
        DistributedType,
        GradientAccumulationPlugin,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
@@ -589,8 +588,10 @@ class Trainer:
            if not _is_peft_model(unwrapped_model)
            else unwrapped_model.get_base_model().forward
        )

        self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters
        forward_params = inspect.signature(model_forward).parameters
        self.model_accepts_loss_kwargs = (
            "loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD
        )

        self.neftune_noise_alpha = args.neftune_noise_alpha

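The tightened check above only treats loss_kwargs as supported when it is a genuine **loss_kwargs catch-all, not an ordinary keyword argument that happens to share the name. A small self-contained sketch of the inspection, with made-up forward signatures:

import inspect

def forward_with_kwargs(input_ids, **loss_kwargs): ...
def forward_with_plain_arg(input_ids, loss_kwargs=None): ...

def accepts_loss_kwargs(fn):
    params = inspect.signature(fn).parameters
    return "loss_kwargs" in params and params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD

print(accepts_loss_kwargs(forward_with_kwargs))     # True
print(accepts_loss_kwargs(forward_with_plain_arg))  # False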
@@ -2424,7 +2425,7 @@ class Trainer:
                update_step += 1
                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                for inputs in batch_samples:
                for i, inputs in enumerate(batch_samples):
                    step += 1
                    total_batched_samples += 1
                    is_last_step_and_steps_less_than_grad_acc = (
@@ -2470,7 +2471,13 @@ class Trainer:
                    if step % args.gradient_accumulation_steps == 0:
                        self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                    with self.accelerator.accumulate(model):
                    # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
                    context = (
                        functools.partial(self.accelerator.no_sync, model=model)
                        if i == len(batch_samples) - 1
                        else contextlib.nullcontext
                    )
                    with context():
                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

                    if (
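The loop above stops relying on accelerator.accumulate and instead picks a context per micro-batch. As a hedged, generic sketch of the underlying DDP accumulation pattern (skip the gradient all-reduce on intermediate micro-batches and synchronize on the last one), assuming an already-prepared Accelerate accelerator, a model, a list of micro-batches, and a training_step callable:

import contextlib
import functools

for i, inputs in enumerate(batch_samples):
    # Generic accumulation pattern: only the final micro-batch triggers gradient sync.
    context = (
        functools.partial(accelerator.no_sync, model=model)
        if i != len(batch_samples) - 1
        else contextlib.nullcontext
    )
    with context():
        loss = training_step(model, inputs)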
@@ -3602,10 +3609,11 @@ class Trainer:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss *= self.args.gradient_accumulation_steps
            self.accelerator.backward(loss, **kwargs)

        return loss.detach() / self.args.gradient_accumulation_steps
            # Finally we need to normalize the loss for reporting
            if num_items_in_batch is None:
                return loss.detach() / self.args.gradient_accumulation_steps
            return loss.detach()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
@@ -3650,6 +3658,9 @@ class Trainer:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
            loss *= self.accelerator.num_processes

        return (loss, outputs) if return_outputs else loss

    def is_local_process_zero(self) -> bool:
@@ -4894,24 +4905,21 @@ class Trainer:
            self.repo.git_push()

    def create_accelerator_and_postprocess(self):
        # We explicitly don't rely on the `Accelerator` to do gradient accumulation
        grad_acc_kwargs = {}
        if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
            grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

        # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
        if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
            # raise because we do not know which setting is intended.
            raise ValueError(
                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
            )
        elif "num_steps" not in grad_acc_kwargs:
            # take the gradient_accumulation_steps setting from TrainingArguments.
            grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

        grad_acc_kwargs["sync_with_dataloader"] = False

        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
        if "num_steps" in grad_acc_kwargs:
            if self.args.gradient_accumulation_steps > 1:
                # raise because we do not know which setting is intended.
                raise ValueError(
                    "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
                    "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
                )
            else:
                self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

        accelerator_config = self.args.accelerator_config.to_dict()

@@ -4942,7 +4950,6 @@ class Trainer:

        args = {
            "deepspeed_plugin": self.args.deepspeed_plugin,
            "gradient_accumulation_plugin": gradient_accumulation_plugin,
        }
        if is_accelerate_available("0.28.0"):
            args["dataloader_config"] = dataloader_config
@@ -5038,12 +5045,18 @@ class Trainer:
                batch_samples += [next(epoch_iterator)]
            except StopIteration:
                break

        # Keep default behavior the same
        if not self.model_accepts_loss_kwargs:
            return batch_samples, None

        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
            # For now we don't support object detection
            try:
                num_items_in_batch = sum(
                    [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
                )
            except TypeError:
                num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
            except (TypeError, AttributeError):
                pass

        if self.args.average_tokens_across_devices:
            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
        return batch_samples, num_items_in_batch

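The get_batch_samples hunk above counts the label tokens that actually contribute to the loss (everything not equal to -100, shifted by one as in causal-LM training), so the loss can later be normalized per token rather than per micro-batch. A toy illustration with made-up labels:

import torch

# Two micro-batches; -100 marks ignored (padding/prompt) positions.
batch_samples = [
    {"labels": torch.tensor([[-100, 5, 6, -100]])},
    {"labels": torch.tensor([[7, 8, -100, -100]])},
]
# Shift by one, then count the non-ignored targets.
num_items_in_batch = sum(
    batch["labels"][..., 1:].ne(-100).sum().item() for batch in batch_samples
)
print(num_items_in_batch)  # 3 -> labels 5, 6 and 8 (the leading 7 is dropped by the shift)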
@@ -1530,6 +1530,15 @@ class TrainingArguments:
        },
    )

    average_tokens_across_devices: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
            "synchronize num_tokens_in_batch for precise loss calculation. Reference: "
            "https://github.com/huggingface/transformers/issues/34242"
        },
    )

    def __post_init__(self):
        # Parse in args that could be `dict` sent in from the CLI as a string
        for field in _VALID_DICT_FIELDS:
@@ -1763,6 +1772,19 @@ class TrainingArguments:
        if self.framework == "pt" and is_torch_available():
            self.device

        # Disable average tokens when using single device
        if self.average_tokens_across_devices:
            try:
                if self.world_size == 1:
                    logger.warning(
                        "average_tokens_across_devices is set to True but it is invalid when world size is"
                        "1. Turn it to False automatically."
                    )
                    self.average_tokens_across_devices = False
            except ImportError as e:
                logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.")
                self.average_tokens_across_devices = False

        if self.torchdynamo is not None:
            warnings.warn(
                "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"

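The new flag is switched on like any other TrainingArguments field; it only has an effect in multi-GPU runs, since the __post_init__ hunk above turns it back off when world_size is 1. A minimal, hypothetical configuration (all other values are placeholders):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    average_tokens_across_devices=True,  # all-reduce token counts for per-token loss normalization
)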
@@ -1416,7 +1416,7 @@ class HFTracer(Tracer):
        your custom tracer.
        """
        attribute = HFAttribute(obj, "keys")()
        if obj.node.target == "**kwargs":
        if obj.node.target.startswith("**"):
            return attribute._metadata
        return attribute


@@ -304,7 +304,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="PR #34283 made changes to the forward function.")
    def test_torch_fx_output_loss(self):
        super().test_torch_fx_output_loss()


@ -235,6 +235,35 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs throw an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        We also need to test multi-image cases when one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            _ = model(**input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in the text
            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
            input_ids = input_dict["input_ids"][:1]
            pixel_values = input_dict["pixel_values"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values)

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )

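An illustrative sketch of the behaviour the new test above locks in (the checkpoint, image URL and exact error text are assumptions, not taken from the diff): passing one image together with two image tokens should fail early with a ValueError instead of an obscure shape mismatch.

    import requests
    from PIL import Image
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    model_id = "llava-hf/llava-1.5-7b-hf"  # placeholder checkpoint
    processor = AutoProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(model_id)

    image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
    prompt = "USER: <image> <image>\nDescribe both images. ASSISTANT:"  # two image tokens, one image
    inputs = processor(text=prompt, images=[image], return_tensors="pt")

    try:
        model(**inputs)
    except ValueError as err:
        print(err)  # message explains that image features and image tokens do not match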
@ -283,6 +283,38 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs throw an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        We also need to test multi-image cases when one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            _ = model(**input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in the text
            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
            input_ids = input_dict["input_ids"][:1]
            pixel_values = input_dict["pixel_values"][:1]
            image_sizes = input_dict["image_sizes"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )

@ -303,6 +303,38 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs throw an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        We also need to test multi-image cases when one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            _ = model(**input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in the text
            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
            input_ids = input_dict["input_ids"][:1]
            pixel_values = input_dict["pixel_values"][:1]
            image_sizes = input_dict["image_sizes"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )

@ -356,7 +356,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="PR #34283 made changes to the forward function.")
    def test_torch_fx_output_loss(self):
        super().test_torch_fx_output_loss()


@ -356,7 +356,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="PR #34283 made changes to the forward function.")
    def test_torch_fx_output_loss(self):
        super().test_torch_fx_output_loss()


@ -236,6 +236,36 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs throw an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        We also need to test multi-image cases when one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            _ = model(**input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in the text
            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
            input_ids = input_dict["input_ids"][:1]
            pixel_values = input_dict["pixel_values"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values)

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )

@ -419,6 +419,7 @@ class Pix2StructModelTester:
@require_torch
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
    all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
    pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
@ -445,6 +446,16 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
                ),
            )

    def test_generative_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_generative_model_classes:
            model = model_class(config).eval().to(torch_device)

            output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10)
            output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10)

            torch.testing.assert_close(output, output_use_cache)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

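A small helper sketch of the property the new `test_generative_model` above asserts (the helper name and inputs are assumptions, not from the diff): greedy generation should produce identical token ids with and without the KV cache.

    import torch

    def check_cache_equivalence(model, inputs, n_tokens=10):
        # Token-for-token agreement between cached and uncached greedy decoding.
        without_cache = model.generate(**inputs, use_cache=False, min_new_tokens=n_tokens, max_new_tokens=n_tokens)
        with_cache = model.generate(**inputs, use_cache=True, min_new_tokens=n_tokens, max_new_tokens=n_tokens)
        torch.testing.assert_close(without_cache, with_cache)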
@ -14,22 +14,16 @@
# limitations under the License.
"""Testing suite for the PyTorch Pixtral model."""

import gc
import unittest

import requests

from transformers import (
    AutoProcessor,
    PixtralVisionConfig,
    PixtralVisionModel,
    is_torch_available,
    is_vision_available,
)
from transformers.testing_utils import (
    require_bitsandbytes,
    require_torch,
    slow,
    torch_device,
)

@ -43,7 +37,7 @@ else:
    is_torch_greater_or_equal_than_2_0 = False

if is_vision_available():
    from PIL import Image
    pass


class PixtralVisionModelTester:
@ -148,6 +142,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
    test_pruning = False
    test_head_masking = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = PixtralVisionModelTester(self)
@ -258,35 +253,3 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
    @unittest.skip(reason="Not supported yet")
    def test_determinism(self):
        pass


@require_torch
class PixtralVisionModelIntegrationTest(unittest.TestCase):
    def setUp(self):
        self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test(self):
        # Let's make sure we test the preprocessing to replace what is used
        model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)

        prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
        image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
        raw_image = Image.open(requests.get(image_file, stream=True).raw)
        inputs = self.processor(prompt, raw_image, return_tensors="pt")

        EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]])  # fmt: skip
        self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

        output = model.generate(**inputs, max_new_tokens=20)
        EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,"  # fmt: skip

        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

@ -171,7 +171,7 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
            input_ids[0].tolist(),
            # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
            [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
        )
                    )
        # fmt: on

        # Test passing in a url
@ -246,6 +246,25 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
        )
        # fmt: on

    def test_processor_returns_full_length_batches(self):
        # to avoid https://github.com/huggingface/transformers/issues/34204
        processor = self.processor_class.from_pretrained(self.tmpdirname)
        prompt_string = [
            "USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
        ] * 5
        processor.tokenizer.pad_token = "</s>"
        image_inputs = [self.image_0] * 5

        # Make small for checking image token expansion
        processor.image_processor.size = {"longest_edge": 30}
        processor.image_processor.patch_size = {"height": 2, "width": 2}

        # Test passing in an image
        inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
        self.assertIn("input_ids", inputs_image)
        self.assertTrue(len(inputs_image["input_ids"]) == 5)
        self.assertTrue(len(inputs_image["pixel_values"]) == 5)

    # Override as PixtralProcessor needs nested images to work properly with batched inputs
    @require_vision
    def prepare_image_inputs(self, batch_size: Optional[int] = None):

@ -368,7 +368,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="PR #34283 made changes to the forward function.")
    def test_torch_fx_output_loss(self):
        super().test_torch_fx_output_loss()


@ -391,7 +391,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="PR #34283 made changes to the forward function.")
    def test_torch_fx_output_loss(self):
        super().test_torch_fx_output_loss()


@ -58,7 +58,7 @@ class Qwen2VLVisionText2TextModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        batch_size=3,
        seq_length=7,
        num_channels=3,
        ignore_index=-100,
@ -245,6 +245,40 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs throw an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        We also need to test multi-image cases when one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            _ = model(**input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in the text
            patch_size = config.vision_config.patch_size
            one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
            input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
            input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
            input_ids = input_dict["input_ids"][:1]
            pixel_values = input_dict["pixel_values"][:one_img_length]
            image_grid_thw = input_dict["image_grid_thw"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )

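A toy arithmetic sketch of the slicing used in the Qwen2-VL test above (the sizes are made-up values, not the tester's real configuration): `pixel_values` is a flat stack of patches, so one square image contributes `image_size**2 // patch_size**2` rows.

    image_size, patch_size = 28, 14               # assumed toy values
    one_img_length = (image_size**2) // (patch_size**2)
    print(one_img_length)                         # 4 patch rows for this toy image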
@ -116,9 +116,9 @@ class VideoLlavaVisionText2TextModelTester:
        self.batch_size = 5
        self.num_channels = 3
        self.image_size = 224
        self.encoder_seq_length = 64
        self.encoder_seq_length = 246
        self.num_image_tokens = 25
        self.num_video_tokens = 26
        self.num_video_tokens = 26 * self.num_frames
        self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens

    def get_config(self):
@ -262,7 +262,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
            # if we remove some images from the inputs, leaving only one,
            # an image-number mismatch error should be raised
            inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
            with self.assertRaises(RuntimeError):
            with self.assertRaises(ValueError):
                _ = model(**inputs)

    def test_video_only_input(self):
@ -396,6 +396,35 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs throw an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        We also need to test multi-image cases when one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            _ = model(**input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in the text
            input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
            input_ids = input_dict["input_ids"][:1]
            pixel_values = input_dict["pixel_values_images"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values_images=pixel_values)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            _ = model(input_ids=input_ids, pixel_values_images=pixel_values)


@require_torch
class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):

@ -217,6 +217,36 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs throw an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        We also need to test multi-image cases when one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            _ = model(**input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in the text
            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # simulate the multi-image case by concatenating inputs where each has exactly one image/image token
            input_ids = input_dict["input_ids"][:1]
            pixel_values = input_dict["pixel_values"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values)

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )

@ -208,6 +208,26 @@ class TorchAoTest(unittest.TestCase):

        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)

    def test_int8_dynamic_activation_int8_weight_quant(self):
        """
        Simple LLM model testing int8_dynamic_activation_int8_weight
        """
        quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")

        # Note: we quantize the bfloat16 model on the fly to int8
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=torch_device,
            quantization_config=quant_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)


if __name__ == "__main__":
    unittest.main()

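For context, a hedged usage sketch of the quantization scheme exercised by the new test (the checkpoint name is a placeholder assumption; `TorchAoConfig` and the "int8_dynamic_activation_int8_weight" key come from the diff itself):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

    model_name = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=quant_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(output[0], skip_special_tokens=True))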
@ -272,6 +272,19 @@ class RepeatDataset:
        return {"input_ids": self.x, "labels": self.x}


class SequenceClassificationDataset:
    def __init__(self, length=64, vocab_size=100, num_labels=5):
        self.length = length
        self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)]
        self.labels = torch.randint(0, num_labels, (length,)).tolist()

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_ids": self.sequences[i], "label": self.labels[i]}


class DynamicShapesDataset:
    def __init__(self, length=64, seed=42, batch_size=8):
        self.length = length
@ -1144,6 +1157,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            train_output = trainer.train()
            self.assertEqual(train_output.global_step, 10)

    def test_torch_compile_loss_func_compatibility(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)

        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                per_device_train_batch_size=2,
                torch_compile=True,
                max_steps=1,  # compile happens on the first step
            )
            trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset)  # noqa
            trainer.train()

    @require_peft
    @require_bitsandbytes
    def test_bnb_compile(self):
@ -3676,9 +3706,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            self.assertEqual(trainer.accelerator.even_batches, False)
            self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

            if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
                self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

    def test_accelerator_config_from_yaml(self):
        # Checks that accelerator kwargs can be passed through
        # and the accelerator is initialized accordingly
@ -3691,8 +3718,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                    "even_batches": False,
                    "use_seedable_sampler": False,
                }
                if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
                    accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
                json.dump(accelerator_config, f)
            config = RegressionModelConfig(a=1.5, b=2.5)
            model = RegressionPreTrainedModel(config)
@ -3706,9 +3731,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            self.assertEqual(trainer.accelerator.even_batches, False)
            self.assertEqual(trainer.accelerator.use_seedable_sampler, False)

            if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
                self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

    def test_accelerator_config_from_dataclass(self):
        # Checks that accelerator kwargs can be passed through
        # and the accelerator is initialized accordingly
@ -3754,10 +3776,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
            trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
            self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
            self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
            self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
            self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
            self.assertEqual(trainer.args.gradient_accumulation_steps, 10)

    def test_accelerator_config_from_partial(self):
        # Checks that accelerator kwargs can be passed through

 | 
			
		||||
                torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        # We can match against an tensor of integers
 | 
			
		||||
        # We can match against an 0D tensor
 | 
			
		||||
        random_test_tensor = torch.randint(0, 100, (1,)).squeeze()
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
 | 
			
		||||
        )
 | 
			
		||||
        # We can match against an 1D tensor (with many items)
 | 
			
		||||
        random_test_tensor = torch.randint(0, 100, (10,))
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user