mirror of
				https://github.com/huggingface/transformers.git
				synced 2025-10-26 22:14:33 +08:00 
			
		
		
		
	Compare commits
	
		
			3 Commits
		
	
	
		
			vision_vis
			...
			v4.44.0
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 984bc11b08 | |||
| af61272239 | |||
| 3e93524a13 | 
| @ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| Array = Any | Array = Any | ||||||
| Dataset = datasets.arrow_dataset.Dataset | Dataset = datasets.arrow_dataset.Dataset | ||||||
|  | |||||||
| @ -60,7 +60,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risk. | # Will error if the minimal version of Transformers is not installed. Remove at your own risk. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt") | require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry | |||||||
|  |  | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| Array = Any | Array = Any | ||||||
| Dataset = datasets.arrow_dataset.Dataset | Dataset = datasets.arrow_dataset.Dataset | ||||||
|  | |||||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -45,7 +45,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") | require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -54,7 +54,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") | require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -49,7 +49,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
|  |  | ||||||
|  | |||||||
| @ -43,7 +43,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used. | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used. | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -46,7 +46,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt") | require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -52,7 +52,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt") | require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -55,7 +55,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
|  |  | ||||||
|  | |||||||
| @ -58,7 +58,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -60,7 +60,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
|  |  | ||||||
|  | |||||||
| @ -54,7 +54,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_ | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  | |||||||
| @ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_ | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
| # You should update this to your particular problem to have better documentation of `model_type` | # You should update this to your particular problem to have better documentation of `model_type` | ||||||
|  | |||||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt") | require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -51,7 +51,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logging.basicConfig(level=logging.INFO) | logging.basicConfig(level=logging.INFO) | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
|  | |||||||
| @ -50,7 +50,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -46,7 +46,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -51,7 +51,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt") | require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -50,7 +50,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
|  |  | ||||||
|  | |||||||
| @ -50,7 +50,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -53,7 +53,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -52,7 +52,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -49,7 +49,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
|  |  | ||||||
|  | |||||||
| @ -48,7 +48,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -49,7 +49,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") | ||||||
|  | |||||||
| @ -52,7 +52,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -57,7 +57,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") | ||||||
|  | |||||||
| @ -51,7 +51,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version( | require_version( | ||||||
|     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt" |     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt" | ||||||
|  | |||||||
| @ -55,7 +55,7 @@ from transformers.utils.versions import require_version | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_ | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  | |||||||
| @ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError): | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  | |||||||
| @ -53,7 +53,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
| # region Checking dependencies | # region Checking dependencies | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry | |||||||
|  |  | ||||||
|  |  | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| task_to_keys = { | task_to_keys = { | ||||||
|     "cola": ("sentence", None), |     "cola": ("sentence", None), | ||||||
|  | |||||||
| @ -56,7 +56,7 @@ from transformers.utils.versions import require_version | |||||||
|  |  | ||||||
| # region Dependencies and constants | # region Dependencies and constants | ||||||
| # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. | ||||||
| check_min_version("4.44.0.dev0") | check_min_version("4.44.0") | ||||||
|  |  | ||||||
| require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @ -430,7 +430,7 @@ install_requires = [ | |||||||
|  |  | ||||||
| setup( | setup( | ||||||
|     name="transformers", |     name="transformers", | ||||||
|     version="4.44.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) |     version="4.44.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) | ||||||
|     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)", |     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)", | ||||||
|     author_email="transformers@huggingface.co", |     author_email="transformers@huggingface.co", | ||||||
|     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow", |     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow", | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ | |||||||
| # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names | # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names | ||||||
| # in the namespace without actually importing anything (and especially none of the backends). | # in the namespace without actually importing anything (and especially none of the backends). | ||||||
|  |  | ||||||
| __version__ = "4.44.0.dev0" | __version__ = "4.44.0" | ||||||
|  |  | ||||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
|  | |||||||
| @ -932,8 +932,6 @@ def _load_state_dict_into_meta_model( | |||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         ): |         ): | ||||||
|             if is_fsdp_enabled(): |  | ||||||
|                 param_device = "cpu" if is_local_dist_rank_0() else "meta" |  | ||||||
|             # For backward compatibility with older versions of `accelerate` and for non-quantized params |             # For backward compatibility with older versions of `accelerate` and for non-quantized params | ||||||
|             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) |             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) | ||||||
|         else: |         else: | ||||||
| @ -944,10 +942,7 @@ def _load_state_dict_into_meta_model( | |||||||
|             if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): |             if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): | ||||||
|                 module, tensor_name = get_module_from_name(model, param_name) |                 module, tensor_name = get_module_from_name(model, param_name) | ||||||
|                 value = getattr(module, tensor_name) |                 value = getattr(module, tensor_name) | ||||||
|                 param_to = "cpu" |                 value = type(value)(value.data.to("cpu"), **value.__dict__) | ||||||
|                 if is_fsdp_enabled() and not is_local_dist_rank_0(): |  | ||||||
|                     param_to = "meta" |  | ||||||
|                 value = type(value)(value.data.to(param_to), **value.__dict__) |  | ||||||
|                 setattr(module, tensor_name, value) |                 setattr(module, tensor_name, value) | ||||||
|             # TODO: consider removing used param_parts from state_dict before return |             # TODO: consider removing used param_parts from state_dict before return | ||||||
|  |  | ||||||
|  | |||||||
| @ -53,6 +53,60 @@ logger = logging.get_logger(__name__) | |||||||
| _CONFIG_FOR_DOC = "NemotronConfig" | _CONFIG_FOR_DOC = "NemotronConfig" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position | ||||||
|  | def _prepare_4d_causal_attention_mask_with_cache_position( | ||||||
|  |     attention_mask: torch.Tensor, | ||||||
|  |     sequence_length: int, | ||||||
|  |     target_length: int, | ||||||
|  |     dtype: torch.dtype, | ||||||
|  |     device: torch.device, | ||||||
|  |     min_dtype: float, | ||||||
|  |     cache_position: torch.Tensor, | ||||||
|  |     batch_size: int, | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape | ||||||
|  |     `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         attention_mask (`torch.Tensor`): | ||||||
|  |             A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. | ||||||
|  |         sequence_length (`int`): | ||||||
|  |             The sequence length being processed. | ||||||
|  |         target_length (`int`): | ||||||
|  |             The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. | ||||||
|  |         dtype (`torch.dtype`): | ||||||
|  |             The dtype to use for the 4D attention mask. | ||||||
|  |         device (`torch.device`): | ||||||
|  |             The device to plcae the 4D attention mask on. | ||||||
|  |         min_dtype (`float`): | ||||||
|  |             The minimum value representable with the dtype `dtype`. | ||||||
|  |         cache_position (`torch.Tensor`): | ||||||
|  |             Indices depicting the position of the input sequence tokens in the sequence. | ||||||
|  |         batch_size (`torch.Tensor`): | ||||||
|  |             Batch size. | ||||||
|  |     """ | ||||||
|  |     if attention_mask is not None and attention_mask.dim() == 4: | ||||||
|  |         # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. | ||||||
|  |         causal_mask = attention_mask | ||||||
|  |     else: | ||||||
|  |         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) | ||||||
|  |         if sequence_length != 1: | ||||||
|  |             causal_mask = torch.triu(causal_mask, diagonal=1) | ||||||
|  |         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) | ||||||
|  |         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) | ||||||
|  |         if attention_mask is not None: | ||||||
|  |             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit | ||||||
|  |             mask_length = attention_mask.shape[-1] | ||||||
|  |             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] | ||||||
|  |             padding_mask = padding_mask == 0 | ||||||
|  |             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( | ||||||
|  |                 padding_mask, min_dtype | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     return causal_mask | ||||||
|  |  | ||||||
|  |  | ||||||
| def _cast_if_autocast_enabled(*args): | def _cast_if_autocast_enabled(*args): | ||||||
|     if not torch.is_autocast_enabled(): |     if not torch.is_autocast_enabled(): | ||||||
|         return args |         return args | ||||||
| @ -902,27 +956,18 @@ class NemotronModel(NemotronPreTrainedModel): | |||||||
|                 else past_seen_tokens + sequence_length + 1 |                 else past_seen_tokens + sequence_length + 1 | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         if attention_mask is not None and attention_mask.dim() == 4: |         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). | ||||||
|             # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing |         causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( | ||||||
|             if attention_mask.max() != 0: |             attention_mask, | ||||||
|                 raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") |             sequence_length=sequence_length, | ||||||
|             causal_mask = attention_mask |             target_length=target_length, | ||||||
|         else: |             dtype=dtype, | ||||||
|             causal_mask = torch.full( |             device=device, | ||||||
|                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |             min_dtype=min_dtype, | ||||||
|             ) |             cache_position=cache_position, | ||||||
|             if sequence_length != 1: |             batch_size=input_tensor.shape[0], | ||||||
|                 causal_mask = torch.triu(causal_mask, diagonal=1) |         ) | ||||||
|             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |  | ||||||
|             causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) |  | ||||||
|             if attention_mask is not None: |  | ||||||
|                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit |  | ||||||
|                 mask_length = attention_mask.shape[-1] |  | ||||||
|                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |  | ||||||
|                 padding_mask = padding_mask == 0 |  | ||||||
|                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |  | ||||||
|                     padding_mask, min_dtype |  | ||||||
|                 ) |  | ||||||
|         if ( |         if ( | ||||||
|             self.config._attn_implementation == "sdpa" |             self.config._attn_implementation == "sdpa" | ||||||
|             and attention_mask is not None |             and attention_mask is not None | ||||||
| @ -1086,11 +1131,36 @@ class NemotronForCausalLM(NemotronPreTrainedModel): | |||||||
|             if past_key_values: |             if past_key_values: | ||||||
|                 position_ids = position_ids[:, -input_ids.shape[1] :] |                 position_ids = position_ids[:, -input_ids.shape[1] :] | ||||||
|  |  | ||||||
|  |                 # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s  `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. | ||||||
|  |                 position_ids = position_ids.clone(memory_format=torch.contiguous_format) | ||||||
|  |  | ||||||
|         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step |         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step | ||||||
|         if inputs_embeds is not None and cache_position[0] == 0: |         if inputs_embeds is not None and cache_position[0] == 0: | ||||||
|             model_inputs = {"inputs_embeds": inputs_embeds} |             model_inputs = {"inputs_embeds": inputs_embeds} | ||||||
|         else: |         else: | ||||||
|             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases |             model_inputs = {"input_ids": input_ids} | ||||||
|  |  | ||||||
|  |         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: | ||||||
|  |             if inputs_embeds is not None: | ||||||
|  |                 batch_size, sequence_length = inputs_embeds.shape | ||||||
|  |                 device = inputs_embeds.device | ||||||
|  |             else: | ||||||
|  |                 batch_size, sequence_length = input_ids.shape | ||||||
|  |                 device = input_ids.device | ||||||
|  |  | ||||||
|  |             dtype = self.lm_head.weight.dtype | ||||||
|  |             min_dtype = torch.finfo(dtype).min | ||||||
|  |  | ||||||
|  |             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( | ||||||
|  |                 attention_mask, | ||||||
|  |                 sequence_length=sequence_length, | ||||||
|  |                 target_length=past_key_values.get_max_length(), | ||||||
|  |                 dtype=dtype, | ||||||
|  |                 device=device, | ||||||
|  |                 min_dtype=min_dtype, | ||||||
|  |                 cache_position=cache_position, | ||||||
|  |                 batch_size=batch_size, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         model_inputs.update( |         model_inputs.update( | ||||||
|             { |             { | ||||||
|  | |||||||
| @ -12,7 +12,6 @@ | |||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import importlib | import importlib | ||||||
| import inspect |  | ||||||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | ||||||
|  |  | ||||||
| from packaging import version | from packaging import version | ||||||
| @ -200,16 +199,11 @@ class Bnb4BitHfQuantizer(HfQuantizer): | |||||||
|                     if unexpected_keys is not None and k in unexpected_keys: |                     if unexpected_keys is not None and k in unexpected_keys: | ||||||
|                         unexpected_keys.remove(k) |                         unexpected_keys.remove(k) | ||||||
|  |  | ||||||
|             param_kwargs = {} |  | ||||||
|             sig = inspect.signature(bnb.nn.Params4bit.from_prequantized) |  | ||||||
|             if "module" in sig.parameters: |  | ||||||
|                 param_kwargs["module"] = module |  | ||||||
|             new_value = bnb.nn.Params4bit.from_prequantized( |             new_value = bnb.nn.Params4bit.from_prequantized( | ||||||
|                 data=param_value, |                 data=param_value, | ||||||
|                 quantized_stats=quantized_stats, |                 quantized_stats=quantized_stats, | ||||||
|                 requires_grad=False, |                 requires_grad=False, | ||||||
|                 device=target_device, |                 device=target_device, | ||||||
|                 **param_kwargs, |  | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             new_value = param_value.to("cpu") |             new_value = param_value.to("cpu") | ||||||
|  | |||||||
| @ -692,7 +692,7 @@ def is_torchdynamo_compiling(): | |||||||
|         import torch |         import torch | ||||||
|  |  | ||||||
|         return torch.compiler.is_compiling() |         return torch.compiler.is_compiling() | ||||||
|     except AttributeError: |     except Exception: | ||||||
|         try: |         try: | ||||||
|             import torch._dynamo as dynamo  # noqa: F401 |             import torch._dynamo as dynamo  # noqa: F401 | ||||||
|  |  | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	