mirror of
				https://github.com/huggingface/transformers.git
				synced 2025-10-26 14:06:45 +08:00 
			
		
		
		
	Compare commits
	
		
			18 Commits
		
	
	
		
			fix-datase
			...
			v4.45.2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 53fad641cf | |||
| 2fd49d2b28 | |||
| 5df4ca826d | |||
| 277ed58f06 | |||
| b1c237fc4e | |||
| ae5f4916de | |||
| 6ea04aaad8 | |||
| be968434fd | |||
| 333ec0a523 | |||
| 3576fec8a3 | |||
| f0686f567a | |||
| 27f03e0a7b | |||
| e71a01a104 | |||
| 0317895840 | |||
| 4ea1c43a10 | |||
| 289edd9e8c | |||
| c64be318fc | |||
| 2ef31dec16 | 
| @ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| Array = Any | ||||
| Dataset = datasets.arrow_dataset.Dataset | ||||
|  | ||||
| @ -60,7 +60,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risk. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| Array = Any | ||||
| Dataset = datasets.arrow_dataset.Dataset | ||||
|  | ||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -45,7 +45,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -54,7 +54,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -49,7 +49,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
|  | ||||
|  | ||||
| @ -43,7 +43,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used. | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used. | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -46,7 +46,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -52,7 +52,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -55,7 +55,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
|  | ||||
|  | ||||
| @ -58,7 +58,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -60,7 +60,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
|  | ||||
|  | ||||
| @ -54,7 +54,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||
|  | ||||
| @ -47,7 +47,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_ | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_ | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
| # You should update this to your particular problem to have better documentation of `model_type` | ||||
|  | ||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -51,7 +51,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logging.basicConfig(level=logging.INFO) | ||||
| logger = get_logger(__name__) | ||||
|  | ||||
| @ -50,7 +50,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -46,7 +46,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -51,7 +51,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -50,7 +50,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
|  | ||||
|  | ||||
| @ -50,7 +50,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -53,7 +53,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -52,7 +52,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | ||||
|  | ||||
| @ -47,7 +47,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -49,7 +49,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
|  | ||||
|  | ||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -49,7 +49,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | ||||
|  | ||||
| @ -52,7 +52,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") | ||||
|  | ||||
| @ -51,7 +51,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version( | ||||
|     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt" | ||||
|  | ||||
| @ -55,7 +55,7 @@ from transformers.utils.versions import require_version | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_ | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError): | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @ -53,7 +53,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
| # region Checking dependencies | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | ||||
|  | ||||
|  | ||||
| @ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry | ||||
|  | ||||
|  | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| task_to_keys = { | ||||
|     "cola": ("sentence", None), | ||||
|  | ||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | ||||
|  | ||||
| # region Dependencies and constants | ||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||
| check_min_version("4.45.0.dev0") | ||||
| check_min_version("4.45.0") | ||||
|  | ||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | ||||
|  | ||||
|  | ||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @ -436,7 +436,7 @@ install_requires = [ | ||||
|  | ||||
| setup( | ||||
|     name="transformers", | ||||
|     version="4.45.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) | ||||
|     version="4.45.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) | ||||
|     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)", | ||||
|     author_email="transformers@huggingface.co", | ||||
|     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow", | ||||
|  | ||||
| @ -18,7 +18,7 @@ | ||||
| # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names | ||||
| # in the namespace without actually importing anything (and especially none of the backends). | ||||
|  | ||||
| __version__ = "4.45.0.dev0" | ||||
| __version__ = "4.45.2" | ||||
|  | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
|  | ||||
| @ -10,6 +10,7 @@ from packaging import version | ||||
|  | ||||
| from .configuration_utils import PretrainedConfig | ||||
| from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging | ||||
| from .utils.deprecation import deprecate_kwarg | ||||
|  | ||||
|  | ||||
| if is_quanto_available(): | ||||
| @ -17,6 +18,7 @@ if is_quanto_available(): | ||||
|     if quanto_version >= version.parse("0.2.0"): | ||||
|         from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4 | ||||
|  | ||||
|  | ||||
| if is_hqq_available(): | ||||
|     from hqq.core.quantize import Quantizer as HQQQuantizer | ||||
|  | ||||
| @ -360,15 +362,12 @@ class DynamicCache(Cache): | ||||
|         ``` | ||||
|     """ | ||||
|  | ||||
|     @deprecate_kwarg("num_hidden_layers", version="4.47.0") | ||||
|     def __init__(self, num_hidden_layers: Optional[int] = None) -> None: | ||||
|         super().__init__() | ||||
|         if num_hidden_layers is None: | ||||
|             self.key_cache: List[torch.Tensor] = [] | ||||
|             self.value_cache: List[torch.Tensor] = [] | ||||
|         else: | ||||
|             self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)] | ||||
|             self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)] | ||||
|         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen | ||||
|         self.key_cache: List[torch.Tensor] = [] | ||||
|         self.value_cache: List[torch.Tensor] = [] | ||||
|  | ||||
|     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: | ||||
|         """ | ||||
| @ -424,11 +423,13 @@ class DynamicCache(Cache): | ||||
|  | ||||
|         # Update the cache | ||||
|         if len(self.key_cache) <= layer_idx: | ||||
|             # There may be skipped layers, fill them with empty lists | ||||
|             for _ in range(len(self.key_cache), layer_idx): | ||||
|                 self.key_cache.append([]) | ||||
|                 self.value_cache.append([]) | ||||
|             self.key_cache.append(key_states) | ||||
|             self.value_cache.append(value_states) | ||||
|         # content on layer cache can be a tensor and checking not tensor causes errors | ||||
|         # so we explicitly check for the empty list | ||||
|         elif self.key_cache[layer_idx] == []: | ||||
|         elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors | ||||
|             self.key_cache[layer_idx] = key_states | ||||
|             self.value_cache[layer_idx] = value_states | ||||
|         else: | ||||
| @ -440,9 +441,13 @@ class DynamicCache(Cache): | ||||
|     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: | ||||
|         """Returns the sequence length of the cached states. A layer index can be optionally passed.""" | ||||
|         # TODO: deprecate this function in favor of `cache_position` | ||||
|         if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []): | ||||
|             return 0 | ||||
|         return self.key_cache[layer_idx].shape[-2] | ||||
|         is_empty_layer = ( | ||||
|             len(self.key_cache) == 0  # no cache in any layer | ||||
|             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it | ||||
|             or len(self.key_cache[layer_idx]) == 0  # the layer has no cache | ||||
|         ) | ||||
|         layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 | ||||
|         return layer_seq_length | ||||
|  | ||||
|     def get_max_length(self) -> Optional[int]: | ||||
|         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" | ||||
| @ -457,12 +462,13 @@ class DynamicCache(Cache): | ||||
|         return legacy_cache | ||||
|  | ||||
|     @classmethod | ||||
|     @deprecate_kwarg("num_hidden_layers", version="4.47.0") | ||||
|     def from_legacy_cache( | ||||
|         cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None | ||||
|     ) -> "DynamicCache": | ||||
|         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for | ||||
|         backward compatibility.""" | ||||
|         cache = cls(num_hidden_layers) | ||||
|         cache = cls() | ||||
|         if past_key_values is not None: | ||||
|             for layer_idx in range(len(past_key_values)): | ||||
|                 key_states, value_states = past_key_values[layer_idx] | ||||
| @ -485,12 +491,15 @@ class DynamicCache(Cache): | ||||
|                 self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] | ||||
|                 self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] | ||||
|  | ||||
|     def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]: | ||||
|     @deprecate_kwarg("num_hidden_layers", version="4.47.0") | ||||
|     def batch_split( | ||||
|         self, full_batch_size: int, split_size: int, num_hidden_layers: int = None | ||||
|     ) -> List["DynamicCache"]: | ||||
|         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by | ||||
|         `_split_model_inputs()` in `generation.utils`""" | ||||
|         out = [] | ||||
|         for i in range(0, full_batch_size, split_size): | ||||
|             current_split = DynamicCache(num_hidden_layers) | ||||
|             current_split = DynamicCache() | ||||
|             current_split._seen_tokens = self._seen_tokens | ||||
|             current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] | ||||
|             current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] | ||||
| @ -498,10 +507,11 @@ class DynamicCache(Cache): | ||||
|         return out | ||||
|  | ||||
|     @classmethod | ||||
|     def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache": | ||||
|     @deprecate_kwarg("num_hidden_layers", version="4.47.0") | ||||
|     def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache": | ||||
|         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in | ||||
|         `generation.utils`""" | ||||
|         cache = cls(num_hidden_layers) | ||||
|         cache = cls() | ||||
|         for idx in range(len(splits[0])): | ||||
|             key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] | ||||
|             value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] | ||||
| @ -617,7 +627,9 @@ class OffloadedCache(DynamicCache): | ||||
|             self._seen_tokens += key_states.shape[-2] | ||||
|  | ||||
|         # Update the cache | ||||
|         if len(self.key_cache) <= layer_idx: | ||||
|         if len(self.key_cache) < layer_idx: | ||||
|             raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") | ||||
|         elif len(self.key_cache) == layer_idx: | ||||
|             self.key_cache.append(key_states) | ||||
|             self.value_cache.append(value_states) | ||||
|             self.original_device.append(key_states.device) | ||||
| @ -676,7 +688,9 @@ class QuantizedCache(DynamicCache): | ||||
|         if layer_idx == 0: | ||||
|             self._seen_tokens += key_states.shape[-2] | ||||
|  | ||||
|         if len(self.key_cache) <= layer_idx: | ||||
|         if len(self.key_cache) < layer_idx: | ||||
|             raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") | ||||
|         elif len(self.key_cache) == layer_idx: | ||||
|             self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) | ||||
|             self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) | ||||
|             self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) | ||||
| @ -1408,12 +1422,12 @@ class EncoderDecoderCache(Cache): | ||||
|  | ||||
|     @classmethod | ||||
|     def from_legacy_cache( | ||||
|         cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None | ||||
|         cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | ||||
|     ) -> "EncoderDecoderCache": | ||||
|         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" | ||||
|         cache = cls( | ||||
|             self_attention_cache=DynamicCache(num_hidden_layers), | ||||
|             cross_attention_cache=DynamicCache(num_hidden_layers), | ||||
|             self_attention_cache=DynamicCache(), | ||||
|             cross_attention_cache=DynamicCache(), | ||||
|         ) | ||||
|         if past_key_values is not None: | ||||
|             for layer_idx in range(len(past_key_values)): | ||||
| @ -1471,14 +1485,12 @@ class EncoderDecoderCache(Cache): | ||||
|         self.check_dynamic_cache(self.crop.__name__) | ||||
|         self.self_attention_cache.crop(maximum_length) | ||||
|  | ||||
|     def batch_split( | ||||
|         self, full_batch_size: int, split_size: int, num_hidden_layers: int | ||||
|     ) -> "List[EncoderDecoderCache]": | ||||
|     def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": | ||||
|         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by | ||||
|         `_split_model_inputs()` in `generation.utils`""" | ||||
|         self.check_dynamic_cache(self.batch_split.__name__) | ||||
|         self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers) | ||||
|         cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers) | ||||
|         self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) | ||||
|         cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) | ||||
|  | ||||
|         out = [] | ||||
|         for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): | ||||
| @ -1486,11 +1498,11 @@ class EncoderDecoderCache(Cache): | ||||
|         return out | ||||
|  | ||||
|     @classmethod | ||||
|     def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache": | ||||
|     def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": | ||||
|         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in | ||||
|         `generation.utils`""" | ||||
|         self_attention_cache = DynamicCache(num_hidden_layers) | ||||
|         cross_attention_cache = DynamicCache(num_hidden_layers) | ||||
|         self_attention_cache = DynamicCache() | ||||
|         cross_attention_cache = DynamicCache() | ||||
|         for idx in range(len(splits[0])): | ||||
|             layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0) | ||||
|             layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0) | ||||
|  | ||||
| @ -380,11 +380,14 @@ class PretrainedConfig(PushToHubMixin): | ||||
|  | ||||
|         non_default_generation_parameters = self._get_non_default_generation_parameters() | ||||
|         if len(non_default_generation_parameters) > 0: | ||||
|             raise ValueError( | ||||
|             # TODO (joao): this should be an exception if the user has modified the loaded config. See #33886 | ||||
|             warnings.warn( | ||||
|                 "Some non-default generation parameters are set in the model config. These should go into either a) " | ||||
|                 "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file " | ||||
|                 "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) " | ||||
|                 f"\nNon-default generation parameters: {str(non_default_generation_parameters)}" | ||||
|                 "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)." | ||||
|                 "This warning will become an exception in the future." | ||||
|                 f"\nNon-default generation parameters: {str(non_default_generation_parameters)}", | ||||
|                 UserWarning, | ||||
|             ) | ||||
|  | ||||
|         os.makedirs(save_directory, exist_ok=True) | ||||
|  | ||||
| @ -1602,11 +1602,10 @@ class GenerationMixin: | ||||
|         # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that | ||||
|         # keeps copying the cache thus using much more memory | ||||
|         else: | ||||
|             num_hidden_layers = self.config.get_text_config().num_hidden_layers | ||||
|             model_kwargs[cache_name] = ( | ||||
|                 DynamicCache(num_hidden_layers) | ||||
|                 DynamicCache() | ||||
|                 if not requires_cross_attention_cache | ||||
|                 else EncoderDecoderCache(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers)) | ||||
|                 else EncoderDecoderCache(DynamicCache(), DynamicCache()) | ||||
|             ) | ||||
|  | ||||
|     def _supports_num_logits_to_keep(self) -> bool: | ||||
|  | ||||
| @ -360,13 +360,23 @@ ROPE_INIT_FUNCTIONS = { | ||||
| } | ||||
|  | ||||
|  | ||||
| def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None): | ||||
| def _check_received_keys( | ||||
|     rope_type: str, | ||||
|     received_keys: set, | ||||
|     required_keys: set, | ||||
|     optional_keys: Optional[set] = None, | ||||
|     ignore_keys: Optional[set] = None, | ||||
| ): | ||||
|     """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" | ||||
|     # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present | ||||
|     if "type" in received_keys: | ||||
|         received_keys -= {"type"} | ||||
|         required_keys.add("rope_type") | ||||
|  | ||||
|     # Some models need to store model-specific keys, and we don't want to throw warning at them | ||||
|     if ignore_keys is not None: | ||||
|         received_keys -= ignore_keys | ||||
|  | ||||
|     missing_keys = required_keys - received_keys | ||||
|     if missing_keys: | ||||
|         raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") | ||||
| @ -379,47 +389,47 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, | ||||
|         logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") | ||||
|  | ||||
|  | ||||
| def _validate_default_rope_parameters(config: PretrainedConfig): | ||||
| def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): | ||||
|     rope_scaling = config.rope_scaling | ||||
|     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type" | ||||
|     required_keys = {"rope_type"} | ||||
|     received_keys = set(rope_scaling.keys()) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) | ||||
|  | ||||
|  | ||||
| def _validate_linear_scaling_rope_parameters(config: PretrainedConfig): | ||||
| def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): | ||||
|     rope_scaling = config.rope_scaling | ||||
|     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type" | ||||
|     required_keys = {"rope_type", "factor"} | ||||
|     received_keys = set(rope_scaling.keys()) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) | ||||
|  | ||||
|     factor = rope_scaling["factor"] | ||||
|     if factor is None or not isinstance(factor, float) or factor < 1.0: | ||||
|         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") | ||||
|  | ||||
|  | ||||
| def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig): | ||||
| def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): | ||||
|     rope_scaling = config.rope_scaling | ||||
|     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type" | ||||
|     required_keys = {"rope_type", "factor"} | ||||
|     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` | ||||
|     optional_keys = {"original_max_position_embeddings"} | ||||
|     received_keys = set(rope_scaling.keys()) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, optional_keys) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) | ||||
|  | ||||
|     factor = rope_scaling["factor"] | ||||
|     if factor is None or not isinstance(factor, float) or factor < 1.0: | ||||
|         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") | ||||
|  | ||||
|  | ||||
| def _validate_yarn_parameters(config: PretrainedConfig): | ||||
| def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): | ||||
|     rope_scaling = config.rope_scaling | ||||
|     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type" | ||||
|     required_keys = {"rope_type", "factor"} | ||||
|     optional_keys = {"attention_factor", "beta_fast", "beta_slow"} | ||||
|     received_keys = set(rope_scaling.keys()) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, optional_keys) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) | ||||
|  | ||||
|     factor = rope_scaling["factor"] | ||||
|     if factor is None or not isinstance(factor, float) or factor < 1.0: | ||||
| @ -444,14 +454,14 @@ def _validate_yarn_parameters(config: PretrainedConfig): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def _validate_longrope_parameters(config: PretrainedConfig): | ||||
| def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): | ||||
|     rope_scaling = config.rope_scaling | ||||
|     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type" | ||||
|     required_keys = {"rope_type", "short_factor", "long_factor"} | ||||
|     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` | ||||
|     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} | ||||
|     received_keys = set(rope_scaling.keys()) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, optional_keys) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) | ||||
|  | ||||
|     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 | ||||
|     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) | ||||
| @ -494,12 +504,12 @@ def _validate_longrope_parameters(config: PretrainedConfig): | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| def _validate_llama3_parameters(config: PretrainedConfig): | ||||
| def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): | ||||
|     rope_scaling = config.rope_scaling | ||||
|     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type" | ||||
|     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} | ||||
|     received_keys = set(rope_scaling.keys()) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys) | ||||
|     _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) | ||||
|  | ||||
|     factor = rope_scaling["factor"] | ||||
|     if factor is None or not isinstance(factor, float) or factor < 1.0: | ||||
| @ -541,7 +551,7 @@ ROPE_VALIDATION_FUNCTIONS = { | ||||
| } | ||||
|  | ||||
|  | ||||
| def rope_config_validation(config: PretrainedConfig): | ||||
| def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): | ||||
|     """ | ||||
|     Validate the RoPE config arguments, given a `PretrainedConfig` object | ||||
|     """ | ||||
| @ -553,7 +563,7 @@ def rope_config_validation(config: PretrainedConfig): | ||||
|     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) | ||||
|     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) | ||||
|     if validation_fn is not None: | ||||
|         validation_fn(config) | ||||
|         validation_fn(config, ignore_keys=ignore_keys) | ||||
|     else: | ||||
|         logger.warning( | ||||
|             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" | ||||
|  | ||||
| @ -1645,6 +1645,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix | ||||
|         # Model class overwrites `generate` (e.g. time series models) -> can generate | ||||
|         if str(cls.__name__) in str(cls.generate): | ||||
|             return True | ||||
|         # The class inherits from a class that can generate (recursive check) -> can generate | ||||
|         for base in cls.__bases__: | ||||
|             if not hasattr(base, "can_generate"): | ||||
|                 continue | ||||
|             if "PreTrainedModel" not in str(base) and base.can_generate(): | ||||
|                 return True | ||||
|         # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this | ||||
|         # was how we detected whether a model could generate. | ||||
|         if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): | ||||
|  | ||||
| @ -88,6 +88,9 @@ class BertTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original BERT). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -105,6 +108,7 @@ class BertTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="[MASK]", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -136,6 +140,7 @@ class BertTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -91,6 +91,9 @@ class ConvBertTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original ConvBERT). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -108,6 +111,7 @@ class ConvBertTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="[MASK]", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -139,6 +143,7 @@ class ConvBertTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -90,6 +90,9 @@ class DistilBertTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original BERT). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -108,6 +111,7 @@ class DistilBertTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="[MASK]", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -138,6 +142,7 @@ class DistilBertTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -90,6 +90,9 @@ class ElectraTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original Electra). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -107,6 +110,7 @@ class ElectraTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="[MASK]", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -138,6 +142,7 @@ class ElectraTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -107,6 +107,9 @@ class FunnelTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original BERT). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -127,6 +130,7 @@ class FunnelTokenizer(PreTrainedTokenizer): | ||||
|         eos_token="</s>", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -159,6 +163,7 @@ class FunnelTokenizer(PreTrainedTokenizer): | ||||
|             eos_token=eos_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -1348,17 +1348,18 @@ class Idefics2Model(Idefics2PreTrainedModel): | ||||
|         past_seen_tokens = 0 | ||||
|         # kept for BC (non `Cache` `past_key_values` inputs) | ||||
|         return_legacy_cache = False | ||||
|         if use_cache and not isinstance(past_key_values, Cache): | ||||
|             return_legacy_cache = True | ||||
|             if past_key_values is None: | ||||
|                 past_key_values = DynamicCache() | ||||
|             else: | ||||
|                 past_key_values = DynamicCache.from_legacy_cache(past_key_values) | ||||
|                 logger.warning_once( | ||||
|                     "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " | ||||
|                     "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " | ||||
|                     "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" | ||||
|                 ) | ||||
|         if use_cache: | ||||
|             if not isinstance(past_key_values, Cache): | ||||
|                 return_legacy_cache = True | ||||
|                 if past_key_values is None: | ||||
|                     past_key_values = DynamicCache() | ||||
|                 else: | ||||
|                     past_key_values = DynamicCache.from_legacy_cache(past_key_values) | ||||
|                     logger.warning_once( | ||||
|                         "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " | ||||
|                         "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " | ||||
|                         "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" | ||||
|                     ) | ||||
|             past_seen_tokens = past_key_values.get_seq_length() | ||||
|  | ||||
|         if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: | ||||
|  | ||||
| @ -91,6 +91,9 @@ class LayoutLMTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original LayoutLM). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -108,6 +111,7 @@ class LayoutLMTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="[MASK]", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -139,6 +143,7 @@ class LayoutLMTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -90,6 +90,9 @@ class LxmertTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original Lxmert). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -107,6 +110,7 @@ class LxmertTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="[MASK]", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -138,6 +142,7 @@ class LxmertTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -12,11 +12,9 @@ | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ | ||||
| Processor class for Mllama. | ||||
| """ | ||||
|  | ||||
| from statistics import mean | ||||
| """Processor class for Mllama.""" | ||||
|  | ||||
| from typing import List, Optional, Union | ||||
|  | ||||
| import numpy as np | ||||
| @ -296,25 +294,27 @@ class MllamaProcessor(ProcessorMixin): | ||||
|             encoding = self.tokenizer(text, **text_kwargs) | ||||
|             data.update(encoding) | ||||
|  | ||||
|         n_images_in_images = [0] | ||||
|         if images is not None: | ||||
|             images = make_list_of_images(images) | ||||
|             n_images_in_images = [len(sample) for sample in images] | ||||
|  | ||||
|             if text is not None: | ||||
|                 if ( | ||||
|                     not all(batch_img_per_prompt == n_images_in_images for batch_img_per_prompt in n_images_in_text) | ||||
|                     and len(text) > 1 | ||||
|                 ): | ||||
|         if text is not None: | ||||
|             if any(batch_img == 0 for batch_img in n_images_in_text) and not all( | ||||
|                 batch_img == 0 for batch_img in n_images_in_text | ||||
|             ): | ||||
|                 raise ValueError( | ||||
|                     "If a batch of text is provided, there should be either no images or at least one image per sample" | ||||
|                 ) | ||||
|             if sum(n_images_in_images) != sum(n_images_in_text): | ||||
|                 if images is None: | ||||
|                     raise ValueError("No image were provided, but there are image tokens in the prompt") | ||||
|                 else: | ||||
|                     raise ValueError( | ||||
|                         f"The number of images in each batch {n_images_in_text} should be the same  {n_images_in_images} should be the same. Yes, the model does not \ | ||||
|                         support having a different number of images per batch." | ||||
|                     ) | ||||
|                 if int(mean(n_images_in_text)) != int(mean(n_images_in_images)): | ||||
|                     raise ValueError( | ||||
|                         f"The number of images in the text ({n_images_in_text}) should be the same as in the number of provided images ({n_images_in_images}) \ | ||||
|                         should be the same." | ||||
|                         f"The number of image token ({sum(n_images_in_images)}) should be the same as in the number of provided images ({sum(n_images_in_images)})" | ||||
|                     ) | ||||
|  | ||||
|         if images is not None: | ||||
|             image_features = self.image_processor(images, **images_kwargs) | ||||
|             num_tiles = image_features.pop("num_tiles") | ||||
|             data.update(image_features) | ||||
|  | ||||
| @ -92,6 +92,9 @@ class MobileBertTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original MobileBERT). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -109,6 +112,7 @@ class MobileBertTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="[MASK]", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -140,6 +144,7 @@ class MobileBertTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -108,6 +108,9 @@ class MPNetTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original BERT). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -128,6 +131,7 @@ class MPNetTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="<mask>", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token | ||||
| @ -170,6 +174,7 @@ class MPNetTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -130,6 +130,7 @@ class PLBartTokenizer(PreTrainedTokenizer): | ||||
|         tgt_lang=None, | ||||
|         sp_model_kwargs: Optional[Dict[str, Any]] = None, | ||||
|         additional_special_tokens=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         # Mask token behave like a normal word, i.e. include the space before it | ||||
| @ -200,6 +201,7 @@ class PLBartTokenizer(PreTrainedTokenizer): | ||||
|             tgt_lang=tgt_lang, | ||||
|             additional_special_tokens=_additional_special_tokens, | ||||
|             sp_model_kwargs=self.sp_model_kwargs, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -308,6 +308,9 @@ class ProphetNetTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original BERT). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -330,6 +333,7 @@ class ProphetNetTokenizer(PreTrainedTokenizer): | ||||
|         mask_token: Optional[str] = "[MASK]", | ||||
|         tokenize_chinese_chars: Optional[bool] = True, | ||||
|         strip_accents: Optional[bool] = None, | ||||
|         clean_up_tokenization_spaces: bool = True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -360,6 +364,7 @@ class ProphetNetTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -235,11 +235,13 @@ class Qwen2VLConfig(PretrainedConfig): | ||||
|  | ||||
|         # Validate the correctness of rotary position embeddings parameters | ||||
|         # BC: if there is a 'type' field, move it to 'rope_type'. | ||||
|         # and change type from 'mrope' to 'default' | ||||
|         # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations | ||||
|         # one can set it to "linear"/"dynamic" etc. to have scaled RoPE | ||||
|         # TODO: @raushan update config in the hub | ||||
|         if self.rope_scaling is not None and "type" in self.rope_scaling: | ||||
|             if self.rope_scaling["type"] == "mrope": | ||||
|                 self.rope_scaling["type"] = "default" | ||||
|             self.rope_scaling["rope_type"] = self.rope_scaling["type"] | ||||
|         rope_config_validation(self) | ||||
|         rope_config_validation(self, ignore_keys={"mrope_section"}) | ||||
|  | ||||
|         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) | ||||
|  | ||||
| @ -279,13 +279,13 @@ class SiglipVisionEmbeddings(nn.Module): | ||||
|         """ | ||||
|  | ||||
|         num_patches = embeddings.shape[1] | ||||
|         num_positions = self.position_embeddings.shape[1] | ||||
|         num_positions = self.position_embedding.weight.shape[0] | ||||
|  | ||||
|         # always interpolate when tracing to ensure the exported model works for dynamic input shapes | ||||
|         if not torch.jit.is_tracing() and num_patches == num_positions and height == width: | ||||
|             return self.position_embeddings | ||||
|             return self.position_embedding(self.position_ids) | ||||
|  | ||||
|         patch_pos_embed = self.position_embeddings | ||||
|         patch_pos_embed = self.position_embedding.weight.unsqueeze(0) | ||||
|  | ||||
|         dim = embeddings.shape[-1] | ||||
|  | ||||
|  | ||||
| @ -91,6 +91,9 @@ class SqueezeBertTokenizer(PreTrainedTokenizer): | ||||
|         strip_accents (`bool`, *optional*): | ||||
|             Whether or not to strip all accents. If this option is not specified, then it will be determined by the | ||||
|             value for `lowercase` (as in the original SqueezeBERT). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -108,6 +111,7 @@ class SqueezeBertTokenizer(PreTrainedTokenizer): | ||||
|         mask_token="[MASK]", | ||||
|         tokenize_chinese_chars=True, | ||||
|         strip_accents=None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not os.path.isfile(vocab_file): | ||||
| @ -139,6 +143,7 @@ class SqueezeBertTokenizer(PreTrainedTokenizer): | ||||
|             mask_token=mask_token, | ||||
|             tokenize_chinese_chars=tokenize_chinese_chars, | ||||
|             strip_accents=strip_accents, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -225,6 +225,9 @@ class TapasTokenizer(PreTrainedTokenizer): | ||||
|             Minimum length of each question in terms of tokens (will be skipped otherwise). | ||||
|         max_question_length (`int`, *optional*): | ||||
|             Maximum length of each question in terms of tokens (will be skipped otherwise). | ||||
|         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like | ||||
|             extra spaces. | ||||
|     """ | ||||
|  | ||||
|     vocab_files_names = VOCAB_FILES_NAMES | ||||
| @ -252,6 +255,7 @@ class TapasTokenizer(PreTrainedTokenizer): | ||||
|         max_question_length=None, | ||||
|         model_max_length: int = 512, | ||||
|         additional_special_tokens: Optional[List[str]] = None, | ||||
|         clean_up_tokenization_spaces=True, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not is_pandas_available(): | ||||
| @ -322,6 +326,7 @@ class TapasTokenizer(PreTrainedTokenizer): | ||||
|             max_question_length=max_question_length, | ||||
|             model_max_length=model_max_length, | ||||
|             additional_special_tokens=additional_special_tokens, | ||||
|             clean_up_tokenization_spaces=clean_up_tokenization_spaces, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -1613,16 +1613,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): | ||||
|  | ||||
|         self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) | ||||
|  | ||||
|         if "clean_up_tokenization_spaces" not in kwargs: | ||||
|             warnings.warn( | ||||
|                 "`clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This " | ||||
|                 "behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. " | ||||
|                 "For more details check this issue: https://github.com/huggingface/transformers/issues/31884", | ||||
|                 FutureWarning, | ||||
|             ) | ||||
|  | ||||
|         # By default, cleaning tokenization spaces for both fast and slow tokenizers | ||||
|         self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True) | ||||
|         self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False) | ||||
|  | ||||
|         # By default, do not split special tokens for both fast and slow tokenizers | ||||
|         self.split_special_tokens = kwargs.pop("split_special_tokens", False) | ||||
|  | ||||
| @ -1846,13 +1846,14 @@ class GenerationTesterMixin: | ||||
|                 input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict | ||||
|             ) | ||||
|             set_seed(seed) | ||||
|             num_hidden_layers = config.get_text_config().num_hidden_layers | ||||
|             if config.is_encoder_decoder: | ||||
|                 cache_cls = EncoderDecoderCache | ||||
|                 past_key_values = cache_cls(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers)) | ||||
|                 past_key_values = cache_cls(DynamicCache(), DynamicCache()) | ||||
|                 past_key_values = cache_cls(DynamicCache(), DynamicCache()) | ||||
|             else: | ||||
|                 cache_cls = DynamicCache | ||||
|                 past_key_values = cache_cls(num_hidden_layers) | ||||
|                 past_key_values = cache_cls() | ||||
|  | ||||
|             new_results = model.generate( | ||||
|                 input_ids, | ||||
|                 attention_mask=attention_mask, | ||||
| @ -3797,6 +3798,29 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi | ||||
|         self.assertEqual(generated_text_no_padding, generated_text_with_padding) | ||||
|         self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.") | ||||
|  | ||||
|     def test_generate_compile_fullgraph_tiny(self): | ||||
|         """ | ||||
|         Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash) | ||||
|         NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the | ||||
|         non-slow tests to prevent regressions! | ||||
|         """ | ||||
|         model = AutoModelForCausalLM.from_pretrained( | ||||
|             "hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto" | ||||
|         ) | ||||
|         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") | ||||
|  | ||||
|         # compile generate | ||||
|         compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") | ||||
|  | ||||
|         # compiled generate does NOT accept parameterization except a) model inputs b) a generation config | ||||
|         generation_config = copy.deepcopy(model.generation_config) | ||||
|         generation_config.pad_token_id = model.config.eos_token_id | ||||
|  | ||||
|         model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt") | ||||
|         model_inputs = model_inputs.to(model.device) | ||||
|         gen_out = compiled_generate(**model_inputs, generation_config=generation_config) | ||||
|         self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1])  # some text was generated | ||||
|  | ||||
|  | ||||
| @require_torch | ||||
| class TokenHealingTestCase(unittest.TestCase): | ||||
|  | ||||
| @ -79,7 +79,7 @@ class ClvpTokenizationTest(TokenizerTesterMixin, unittest.TestCase): | ||||
|     # Copied from transformers.tests.models.gpt2.test_tokenization_gpt2.GPT2TokenizationTest.get_input_output_texts | ||||
|     def get_input_output_texts(self, tokenizer): | ||||
|         input_text = "lower newer" | ||||
|         output_text = "lower newer" | ||||
|         output_text = "lower[SPACE]newer" | ||||
|         return input_text, output_text | ||||
|  | ||||
|     # Copied from transformers.tests.models.layoutxlm.test_tokenization_layoutxlm.LayoutXLMTokenizationTest.test_add_special_tokens | ||||
|  | ||||
| @ -383,45 +383,73 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_beam_sample_generate_dict_output(): | ||||
|     def test_beam_sample_generate_dict_output(self): | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_beam_search_generate_dict_output(): | ||||
|     def test_beam_search_generate_dict_output(self): | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_constrained_beam_search_generate_dict_output(): | ||||
|     def test_constrained_beam_search_generate_dict_output(self): | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_dola_decoding_sample(): | ||||
|     def test_dola_decoding_sample(self): | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_generate_methods_with_num_logits_to_keep(): | ||||
|     def test_generate_methods_with_num_logits_to_keep(self): | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_greedy_generate_dict_outputs(): | ||||
|     def test_greedy_generate_dict_outputs(self): | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_group_beam_search_generate_dict_output(): | ||||
|     def test_group_beam_search_generate_dict_output(self): | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_model_parallel_beam_search(): | ||||
|     def test_model_parallel_beam_search(self): | ||||
|         pass | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_new_cache_format_2(): | ||||
|         pass | ||||
|     @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization) | ||||
|     def test_new_cache_format_0(self): | ||||
|         super().test_new_cache_format_0() | ||||
|  | ||||
|     @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization) | ||||
|     def test_new_cache_format_1(self): | ||||
|         super().test_new_cache_format_1() | ||||
|  | ||||
|     @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization) | ||||
|     def test_new_cache_format_2(self): | ||||
|         super().test_new_cache_format_2() | ||||
|  | ||||
|     @unittest.skip(reason="Failing test, need to fix") | ||||
|     def test_sample_generate_dict_output(): | ||||
|     def test_sample_generate_dict_output(self): | ||||
|         pass | ||||
|  | ||||
|     def test_generate_text_only_with_cache(self): | ||||
|         """ | ||||
|         Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature | ||||
|         required cache modifications (because layers are skipped in practice). This test should prevent regressions. | ||||
|         """ | ||||
|         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() | ||||
|  | ||||
|         for model_class in self.all_model_classes: | ||||
|             model = model_class(config) | ||||
|             model.to(torch_device) | ||||
|             model.eval() | ||||
|  | ||||
|             inputs = self._prepare_for_class(inputs_dict, model_class) | ||||
|  | ||||
|             input_ids = inputs["input_ids"] | ||||
|             del inputs["input_ids"] | ||||
|             del inputs["pixel_values"] | ||||
|  | ||||
|             model.generate(input_ids, use_cache=True) | ||||
|  | ||||
|  | ||||
| @require_torch | ||||
| class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase): | ||||
|  | ||||
| @ -15,6 +15,8 @@ | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from transformers import MllamaProcessor | ||||
| from transformers.testing_utils import require_torch, require_vision | ||||
| from transformers.utils import is_vision_available | ||||
| @ -177,3 +179,119 @@ class MllamaProcessorTest(unittest.TestCase): | ||||
|         rendered_list = self.processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False) | ||||
|         rendered_str = self.processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False) | ||||
|         self.assertEqual(rendered_list, rendered_str) | ||||
|  | ||||
|     def test_process_interleaved_images_prompts_image_splitting(self): | ||||
|         # Test that a single image is processed correctly | ||||
|         inputs = self.processor(images=self.image2, size={"width": 224, "height": 224}) | ||||
|         self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 224, 224)) | ||||
|  | ||||
|         # Test that text is processed correctly | ||||
|         text = "<|begin_of_text|>This is a test sentence.<|end_of_text|>" | ||||
|         inputs = self.processor(text=text) | ||||
|         expected_ids = [128000, 2028, 374, 264, 1296, 11914, 13, 128001] | ||||
|         self.assertEqual(inputs["input_ids"][0], expected_ids) | ||||
|         self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids)) | ||||
|         self.assertEqual(inputs.get("cross_attention_mask"), None) | ||||
|  | ||||
|         # Test a single sample with image and text | ||||
|         image_str = "<|image|>" | ||||
|         text_str = "This is a test sentence." | ||||
|         text = image_str + text_str | ||||
|         inputs = self.processor( | ||||
|             text=text, | ||||
|             images=self.image1, | ||||
|             size={"width": 128, "height": 128}, | ||||
|         ) | ||||
|         expected_ids = [self.image_token_id, self.bos_token_id] + [2028, 374, 264, 1296, 11914, 13] | ||||
|  | ||||
|         self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 128, 128)) | ||||
|         self.assertEqual(inputs["input_ids"][0], expected_ids) | ||||
|         self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids)) | ||||
|         cross_attention_mask = inputs["cross_attention_mask"] | ||||
|         self.assertEqual(cross_attention_mask.shape, (1, 8, 1, 4)) | ||||
|         self.assertTrue( | ||||
|             np.all(cross_attention_mask == 1), f"Cross attention mask is not all ones: {cross_attention_mask}" | ||||
|         ) | ||||
|  | ||||
|         # Test batch | ||||
|         text = [ | ||||
|             "<|image|>This is a test sentence.", | ||||
|             "This is a test sentence.<|image|><|image|>This is a test sentence.", | ||||
|         ] | ||||
|         # fmt: off | ||||
|         expected_ids = [ | ||||
|             [self.image_token_id, self.bos_token_id, 2028, 374, 264, 1296, 11914, 13], | ||||
|             [self.bos_token_id, 2028, 374, 264, 1296, 11914, 13, self.image_token_id, self.image_token_id, 2028, 374, 264, 1296, 11914, 13], | ||||
|         ] | ||||
|         # fmt: on | ||||
|         images = [[self.image1], [self.image1, self.image2]] | ||||
|         inputs = self.processor(text=text, images=images, padding=True, size={"width": 256, "height": 256}) | ||||
|  | ||||
|         self.assertEqual(inputs["pixel_values"].shape, (2, 2, 4, 3, 256, 256)) | ||||
|         for input_ids_i, attention_mask_i, expected_ids_i in zip(inputs["input_ids"], inputs["attention_mask"], expected_ids): | ||||
|             pad_ids = [id for id, m in zip(input_ids_i, attention_mask_i) if m == 0] | ||||
|             input_ids = [id for id, m in zip(input_ids_i, attention_mask_i) if m == 1] | ||||
|             self.assertEqual(input_ids, expected_ids_i) | ||||
|             self.assertEqual(pad_ids, [self.pad_token_id] * len(pad_ids)) | ||||
|  | ||||
|         cross_attention_mask = inputs["cross_attention_mask"] | ||||
|         self.assertEqual(cross_attention_mask.shape, (2, 15, 2, 4)) | ||||
|  | ||||
|         # Check that only the first tile of the first sample is attended to by all text tokens | ||||
|         first_sample_mask = cross_attention_mask[0].copy() | ||||
|         first_image_first_tile_attention = first_sample_mask[:, :1, :1]  # text tokens, images, tiles | ||||
|         self.assertTrue(np.all(first_image_first_tile_attention == 1), f"Cross attention mask is not all ones: {first_image_first_tile_attention}") | ||||
|  | ||||
|         # zero out first tile of first image | ||||
|         first_image_first_tile_attention[:, :1, :1] = 0 | ||||
|         self.assertTrue(np.all(first_image_first_tile_attention == 0), f"Cross attention mask is not all zeros: {first_image_first_tile_attention}") | ||||
|  | ||||
|         # second sample | ||||
|         second_sample_mask = cross_attention_mask[1].copy() | ||||
|         first_image_first_tile_attention = second_sample_mask[7:, :1, :1]  # text tokens, images, tiles | ||||
|         self.assertTrue(np.all(first_image_first_tile_attention == 1), f"Cross attention mask is not all ones: {first_image_first_tile_attention}") | ||||
|  | ||||
|         second_image_two_tiles_attention = second_sample_mask[8:, 1:2, :2]  # text tokens, images, tiles | ||||
|         self.assertTrue(np.all(second_image_two_tiles_attention == 1), f"Cross attention mask is not all ones: {second_image_two_tiles_attention}") | ||||
|  | ||||
|         # zero out both images masks | ||||
|         second_sample_mask[7:, :1, :1] = 0 | ||||
|         second_sample_mask[8:, 1:2, :2] = 0 | ||||
|         self.assertTrue(np.all(second_sample_mask == 0), f"Cross attention mask is not all zeros: {second_sample_mask}") | ||||
|  | ||||
|     def test_process_interleaved_images_prompts_image_error(self): | ||||
|         text = [ | ||||
|             "This is a test sentence.", | ||||
|             "In this other sentence we try some good things", | ||||
|         ] | ||||
|         inputs = self.processor(text=text, images=None, padding=True) | ||||
|         self.assertIsNotNone(inputs["input_ids"]) | ||||
|  | ||||
|         text = [ | ||||
|             "This is a test sentence.<|image|>", | ||||
|             "In this other sentence we try some good things", | ||||
|         ] | ||||
|         with self.assertRaises(ValueError): | ||||
|             self.processor(text=text, images=None, padding=True) | ||||
|  | ||||
|         images = [[self.image1], []] | ||||
|         with self.assertRaises(ValueError): | ||||
|             self.processor(text=text, images=images, padding=True) | ||||
|  | ||||
|         text = [ | ||||
|             "This is a test sentence.<|image|>", | ||||
|             "In this other sentence we try some good things<|image|>", | ||||
|         ] | ||||
|         with self.assertRaises(ValueError): | ||||
|             self.processor(text=text, images=None, padding=True) | ||||
|  | ||||
|         text = [ | ||||
|             "This is a test sentence.<|image|>", | ||||
|             "In this other sentence we try some good things<|image|>", | ||||
|         ] | ||||
|         images = [[self.image1], [self.image2]] | ||||
|         inputs = self.processor(text=text, images=images, padding=True) | ||||
|  | ||||
|         images = [[self.image1, self.image2], []] | ||||
|         with self.assertRaises(ValueError): | ||||
|             self.processor(text=text, images=None, padding=True) | ||||
|  | ||||
| @ -147,8 +147,8 @@ class Wav2Vec2TokenizerTest(unittest.TestCase): | ||||
|         batch_tokens = tokenizer.batch_decode(sample_ids) | ||||
|         batch_tokens_2 = tokenizer.batch_decode(sample_ids, skip_special_tokens=True) | ||||
|  | ||||
|         self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"]) | ||||
|         self.assertEqual(batch_tokens_2, ["HELO!?!?", "BYE BYE"]) | ||||
|         self.assertEqual(batch_tokens, ["HELLO<unk>!? !?$$$", "BYE BYE<unk>$$$"]) | ||||
|         self.assertEqual(batch_tokens_2, ["HELO!? !?", "BYE BYE"]) | ||||
|  | ||||
|     def test_call(self): | ||||
|         # Tests that all calls wrap to encode_plus and batch_encode_plus | ||||
| @ -467,8 +467,8 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): | ||||
|         batch_tokens = tokenizer.batch_decode(sample_ids) | ||||
|         batch_tokens_2 = tokenizer.batch_decode(sample_ids, skip_special_tokens=True) | ||||
|  | ||||
|         self.assertEqual(batch_tokens, ["HELLO<unk>!?!?<new_tokens>$$$", "BYE BYE<unk><new_tokens>$$$"]) | ||||
|         self.assertEqual(batch_tokens_2, ["HELO!?!?<new_tokens>", "BYE BYE<new_tokens>"]) | ||||
|         self.assertEqual(batch_tokens, ["HELLO<unk>!? !?<new_tokens>$$$", "BYE BYE<unk><new_tokens>$$$"]) | ||||
|         self.assertEqual(batch_tokens_2, ["HELO!? !?<new_tokens>", "BYE BYE<new_tokens>"]) | ||||
|  | ||||
|     def test_special_characters_in_vocab(self): | ||||
|         sent = "ʈʰ æ æ̃ ˧ kʰ" | ||||
|  | ||||
| @ -249,7 +249,7 @@ class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): | ||||
|         # fmt: on | ||||
|  | ||||
|         batch_tokens = tokenizer.batch_decode(sample_ids) | ||||
|         self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ!?!? $$$", "j ð s j ð s oːɹ $$$"]) | ||||
|         self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ ! ? ! ? $$$", "j ð s j ð s oːɹ $$$"]) | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_from_offsets(offsets, key): | ||||
|  | ||||
| @ -53,7 +53,7 @@ class CacheTest(unittest.TestCase): | ||||
|     def test_dynamic_cache_retrocompatibility(self): | ||||
|         """Tests that we can convert back and forth between the legacy cache format and DynamicCache""" | ||||
|         legacy_cache = () | ||||
|         new_cache = DynamicCache(num_hidden_layers=10) | ||||
|         new_cache = DynamicCache() | ||||
|  | ||||
|         # Creates a new cache with 10 layers in both formats | ||||
|         for layer_idx in range(10): | ||||
| @ -83,7 +83,7 @@ class CacheTest(unittest.TestCase): | ||||
|                 ) | ||||
|  | ||||
|         # Test 1: We can convert from legacy to new with no changes | ||||
|         from_legacy = DynamicCache.from_legacy_cache(legacy_cache, num_hidden_layers=10) | ||||
|         from_legacy = DynamicCache.from_legacy_cache(legacy_cache) | ||||
|         for layer_idx in range(10): | ||||
|             for key_value_idx in range(2): | ||||
|                 self.assertTrue( | ||||
| @ -103,7 +103,7 @@ class CacheTest(unittest.TestCase): | ||||
|         legacy_reorder_fn = GPT2LMHeadModel._reorder_cache  # An example of a legacy `_reorder_cache` function | ||||
|  | ||||
|         legacy_cache = () | ||||
|         new_cache = DynamicCache(num_hidden_layers=10) | ||||
|         new_cache = DynamicCache() | ||||
|  | ||||
|         # Creates a new cache with 10 layers in both formats | ||||
|         for layer_idx in range(10): | ||||
| @ -240,9 +240,7 @@ class CacheIntegrationTest(unittest.TestCase): | ||||
|         set_seed(0) | ||||
|         gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256) | ||||
|         set_seed(0) | ||||
|         gen_out = model.generate( | ||||
|             **inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache(model.config.num_hidden_layers) | ||||
|         ) | ||||
|         gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache()) | ||||
|         self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist()) | ||||
|  | ||||
|         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) | ||||
| @ -270,9 +268,7 @@ class CacheIntegrationTest(unittest.TestCase): | ||||
|             model.device | ||||
|         ) | ||||
|  | ||||
|         gen_out = model.generate( | ||||
|             **inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache(model.config.num_hidden_layers) | ||||
|         ) | ||||
|         gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache()) | ||||
|         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) | ||||
|         expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] | ||||
|         self.assertListEqual(decoded, expected_text) | ||||
|  | ||||
| @ -313,11 +313,12 @@ class ConfigTestUtils(unittest.TestCase): | ||||
|         old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo) | ||||
|         self.assertEqual(old_configuration.hidden_size, 768) | ||||
|  | ||||
|     def test_saving_config_with_custom_generation_kwargs_raises_exception(self): | ||||
|     def test_saving_config_with_custom_generation_kwargs_raises_warning(self): | ||||
|         config = BertConfig(min_length=3)  # `min_length = 3` is a non-default generation kwarg | ||||
|         with tempfile.TemporaryDirectory() as tmp_dir: | ||||
|             with self.assertRaises(ValueError): | ||||
|             with self.assertWarns(UserWarning) as cm: | ||||
|                 config.save_pretrained(tmp_dir) | ||||
|             self.assertIn("min_length", str(cm.warning)) | ||||
|  | ||||
|     def test_get_non_default_generation_parameters(self): | ||||
|         config = BertConfig() | ||||
|  | ||||
| @ -65,6 +65,19 @@ class RopeTest(unittest.TestCase): | ||||
|                     with self.assertRaises(KeyError): | ||||
|                         rope_config_validation(config) | ||||
|  | ||||
|         # Any other parameters passed to RoPE will raise a warning that a particular key is not used | ||||
|         # But sometimes we can have model-specific RoPE kwargs and bypass warning with `ignore_keys` | ||||
|         model_specific_kwarg = "mrope_sections"  # e.g. in Qwen2-VL | ||||
|  | ||||
|         for rope_type in all_rope_types: | ||||
|             if rope_type == "default": | ||||
|                 config.rope_scaling = {"rope_type": rope_type, model_specific_kwarg: True} | ||||
|                 rope_config_validation(config, ignore_keys={model_specific_kwarg}) | ||||
|                 with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: | ||||
|                     rope_config_validation(config) | ||||
|                     self.assertEqual(len(logs.output), 1) | ||||
|                     self.assertIn(model_specific_kwarg, logs.output[0]) | ||||
|  | ||||
|     def test_default_rope_function_bc(self): | ||||
|         config = LlamaConfig() | ||||
|         device = torch_device | ||||
|  | ||||
| @ -1718,29 +1718,51 @@ class ModelUtilsTest(TestCasePlus): | ||||
|  | ||||
|     def test_can_generate(self): | ||||
|         """Tests the behavior of `PreTrainedModel.can_generate` method.""" | ||||
|         logger = logging.get_logger("transformers.modeling_utils") | ||||
|         logger.warning_once.cache_clear() | ||||
|  | ||||
|         # 1 - By default, a model CAN'T generate | ||||
|         self.assertFalse(BertModel.can_generate()) | ||||
|         can_generate = BertModel.can_generate() | ||||
|         self.assertFalse(can_generate) | ||||
|  | ||||
|         # 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly | ||||
|         class DummyBertWithMixin(BertModel, GenerationMixin): | ||||
|             pass | ||||
|  | ||||
|         self.assertTrue(DummyBertWithMixin.can_generate()) | ||||
|         with CaptureLogger(logger) as cl: | ||||
|             can_generate = DummyBertWithMixin.can_generate() | ||||
|         self.assertTrue("" == cl.out) | ||||
|         self.assertTrue(can_generate) | ||||
|  | ||||
|         # 3 - Alternatively, a model can implement a `generate` method | ||||
|         class DummyBertWithGenerate(BertModel): | ||||
|             def generate(self): | ||||
|                 pass | ||||
|  | ||||
|         self.assertTrue(DummyBertWithGenerate.can_generate()) | ||||
|         with CaptureLogger(logger) as cl: | ||||
|             can_generate = DummyBertWithGenerate.can_generate() | ||||
|         self.assertTrue("" == cl.out) | ||||
|         self.assertTrue(can_generate) | ||||
|  | ||||
|         # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited | ||||
|         # 4 - Finally, it can inherit from a model that can generate | ||||
|         class DummyBertWithParent(DummyBertWithMixin): | ||||
|             pass | ||||
|  | ||||
|         with CaptureLogger(logger) as cl: | ||||
|             can_generate = DummyBertWithParent.can_generate() | ||||
|         self.assertTrue("" == cl.out) | ||||
|         self.assertTrue(can_generate) | ||||
|  | ||||
|         # 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited | ||||
|         # `GenerationMixin`) | ||||
|         class DummyBertWithPrepareInputs(BertModel): | ||||
|             def prepare_inputs_for_generation(self): | ||||
|                 pass | ||||
|  | ||||
|         self.assertTrue(DummyBertWithPrepareInputs.can_generate()) | ||||
|         with CaptureLogger(logger) as cl: | ||||
|             can_generate = DummyBertWithPrepareInputs.can_generate() | ||||
|         self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out) | ||||
|         self.assertTrue(can_generate) | ||||
|  | ||||
|     def test_save_and_load_config_with_custom_generation(self): | ||||
|         """ | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	