Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-21 01:23:56 +08:00

Compare commits: change_bui... → v4.47.1 (8 commits)

Commit SHA1s:
241c04d368
1b6cb1eefc
b1d5d6dd65
bf5d7c3fa3
49952300bb
d5ccfcc39a
0485b6e881
5d7739f15a
@@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset
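This hunk, and the several dozen example-script hunks that follow, pin the examples to the released version: `check_min_version` raises at import time if the installed `transformers` is older than what the script expects. A minimal sketch of that kind of guard, written directly against `packaging`; the real `transformers.utils.check_min_version` may differ in details:

```python
# Hedged sketch of a version guard in the spirit of transformers.utils.check_min_version.
# Assumes the `packaging` library is available; this is not the actual implementation.
from packaging import version

import transformers


def check_min_version_sketch(min_version: str) -> None:
    if version.parse(transformers.__version__) < version.parse(min_version):
        raise ImportError(
            f"This example requires transformers>={min_version}, "
            f"but {transformers.__version__} is installed. "
            "Run: pip install --upgrade transformers"
        )


check_min_version_sketch("4.47.0")  # passes on 4.47.0 and 4.47.1, fails on 4.46.x
```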
@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
@@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = logging.getLogger(__name__)

@@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logging.basicConfig(level=logging.INFO)
 logger = get_logger(__name__)
@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = logging.getLogger(__name__)

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
 # region Checking dependencies
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 task_to_keys = {
     "cola": ("sentence", None),

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # region Dependencies and constants
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
setup.py

@@ -435,7 +435,7 @@ install_requires = [
 setup(
     name="transformers",
-    version="4.47.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.47.1",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="transformers@huggingface.co",
     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.47.0.dev0"
+__version__ = "4.47.1"
 from typing import TYPE_CHECKING
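The `setup.py` comment spells out the accepted version formats (`x.y.z.dev0`, `x.y.z.rc1`, `x.y.z`). A short, hedged illustration of how those strings order under PEP 440, using `packaging` rather than any helper from this repository, which is also why the examples above can require the final `4.47.0` while the package itself ships as `4.47.1`:

```python
from packaging import version

# Pre-releases sort before the final release; patch releases sort after it.
assert (
    version.parse("4.47.0.dev0")
    < version.parse("4.47.0.rc1")
    < version.parse("4.47.0")
    < version.parse("4.47.1")
)
```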
@@ -1325,6 +1325,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **loss_kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):

@@ -1375,11 +1376,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
         lm_loss = None
         if labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]
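The hand-rolled shift-and-cross-entropy block is replaced by the model's generic `self.loss_function`, which accepts extra `**loss_kwargs` (for example a `num_items_in_batch` count during gradient accumulation, as the Trainer changes below suggest). A rough standalone sketch of what such a causal-LM loss computes; an illustration, not the exact Transformers implementation:

```python
from typing import Optional

import torch
import torch.nn.functional as F


def causal_lm_loss_sketch(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
) -> torch.Tensor:
    # Next-token prediction: drop the last logit and the first label, then flatten.
    shift_logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[:, 1:].contiguous().view(-1)
    if num_items_in_batch is None:
        # Per-token mean, matching the removed CrossEntropyLoss() path.
        return F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)
    # Sum over tokens and normalize by the token count of the full (accumulated) batch,
    # so splitting a batch into micro-batches does not change the loss scale.
    summed = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction="sum")
    return summed / num_items_in_batch
```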
@@ -491,6 +491,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
         kwargs_decoder = {
             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
         }
+        if "num_items_in_batch" in kwargs_encoder:
+            kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
         if encoder_outputs is None:
             if inputs is None:
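The composite model routes keyword arguments by prefix: `decoder_*` kwargs go to the decoder and the rest to the encoder. The added lines make sure `num_items_in_batch`, which carries no prefix, still reaches the decoder, where the loss is computed. A small illustration of that routing pattern with made-up kwargs, not the model's actual forward logic:

```python
def split_composite_kwargs(kwargs: dict) -> tuple[dict, dict]:
    """Split kwargs into (encoder_kwargs, decoder_kwargs) by the 'decoder_' prefix."""
    decoder_kwargs = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
    encoder_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
    # Loss bookkeeping such as num_items_in_batch has no prefix but belongs with the
    # sub-model that computes the loss (the decoder).
    if "num_items_in_batch" in encoder_kwargs:
        decoder_kwargs["num_items_in_batch"] = encoder_kwargs.pop("num_items_in_batch")
    return encoder_kwargs, decoder_kwargs


enc, dec = split_composite_kwargs({"attention_mask": None, "decoder_input_ids": [0], "num_items_in_batch": 128})
assert "num_items_in_batch" in dec and "num_items_in_batch" not in enc
```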
@@ -632,7 +632,9 @@ class WhisperGenerationMixin(GenerationMixin):
                cur_bsz=cur_bsz,
                batch_idx_map=batch_idx_map,
            )
-           time_offset = seek.to(torch.float64) * time_precision / input_stride
+           time_offset = (
+               seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
+           )
            seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
            # 6.2 cut out next 30s segment from input features

@@ -1805,6 +1807,7 @@ class WhisperGenerationMixin(GenerationMixin):
        timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
        timestamp_segment_indices.add_(1)
        token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
+       device = seek_sequence.device
        # If whisper predicted a "end of segment" via a timestep token, let's go ever each
        # "end of segment" prediction and slice the decoding into segments accordingly

@@ -1828,8 +1831,12 @@ class WhisperGenerationMixin(GenerationMixin):
                end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
                segments.append(
                    {
-                       "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
-                       "end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
+                       "start": time_offset[prev_idx]
+                       + start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
+                       * time_precision,
+                       "end": time_offset[prev_idx]
+                       + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
+                       * time_precision,
                        "tokens": sliced_tokens,
                        "result": seek_outputs[idx],
                    }

@@ -1856,7 +1863,9 @@ class WhisperGenerationMixin(GenerationMixin):
            last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
            if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
                # no consecutive timestamps but it has a timestamp; use the last one.
-               last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
+               last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
+                   torch.float32 if device.type == "mps" else torch.float64
+               )
            segments = [
                {
                    "start": time_offset[prev_idx],
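All four Whisper hunks make the same fix: timestamp arithmetic that was hard-coded to `torch.float64` now falls back to `torch.float32` on Apple `mps` devices, because the PyTorch MPS backend does not support float64. The repeated conditional can be read as the helper sketched below, which is an illustration only and not a function in the Transformers API:

```python
import torch


def timestamp_dtype(device: torch.device) -> torch.dtype:
    # MPS has no float64 kernels; float32 is ample precision for second-scale timestamps.
    return torch.float32 if device.type == "mps" else torch.float64


# e.g. time_offset = seek.to(timestamp_dtype(seek.device)) * time_precision / input_stride
```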
@@ -38,8 +38,10 @@ is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse(
 is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
+# Cache this result has it's a C FFI call which can be pretty time-consuming
+_torch_distributed_available = torch.distributed.is_available()

-if is_torch_greater_or_equal("2.5"):
+if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
     from torch.distributed.tensor import Replicate
     from torch.distributed.tensor.parallel import (
         ColwiseParallel,
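Besides the version check, the tensor-parallel imports are now also gated on `torch.distributed.is_available()`, with the result cached in a module-level flag (the diff's own comment notes the call is comparatively expensive). The same guard pattern in isolation, as a hedged sketch rather than the module's actual import block:

```python
import torch

# Query once at import time and reuse the cached flag everywhere else.
_torch_distributed_available = torch.distributed.is_available()

if _torch_distributed_available:
    # Only touch torch.distributed submodules when the build actually ships them
    # (CPU-only or minimal builds may not).
    from torch.distributed import ReduceOp  # noqa: F401
```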
@@ -2251,7 +2251,7 @@ class Trainer:
        else:
            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
-       delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
+       delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
        # We need to reset the scheduler, as its parameters may be different on subsequent calls
        if self._created_lr_scheduler:

@@ -2304,12 +2304,13 @@ class Trainer:
            # In case of auto_find_batch_size=True
            # Remove FSDP wrapping from sub-models.
            self.model = unwrap_model(self.model, recursive=True)
-           # configure fsdp plugin for qlora if any
-           self._fsdp_qlora_plugin_updates()
        if delay_optimizer_creation:
-           if use_accelerator_prepare:
-               self.model = self.accelerator.prepare(self.model)
+           # configure fsdp plugin for qlora if any
+           self._fsdp_qlora_plugin_updates()
+           if self.accelerator.mixed_precision != "fp8":
+               self.model = self.accelerator.prepare(self.model)
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        # prepare using `accelerator` prepare

@@ -2516,6 +2517,7 @@ class Trainer:
                context = (
                    functools.partial(self.accelerator.no_sync, model=model)
                    if i != len(batch_samples) - 1
+                   and self.accelerator.distributed_type != DistributedType.DEEPSPEED
                    else contextlib.nullcontext
                )
                with context():

@@ -3649,10 +3651,7 @@ class Trainer:
            return loss_mb.reduce_mean().detach().to(self.args.device)
        with self.compute_loss_context_manager():
-           if self.model_accepts_loss_kwargs:
-               loss = self.compute_loss(model, inputs)
-           else:
-               loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+           loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        del inputs
        if (

@@ -4175,7 +4174,7 @@ class Trainer:
            start_time = time.time()
            model = (
                self.accelerator.prepare(model)
-               if self.is_deepspeed_enabled or self.is_fsdp_enabled
+               if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )
            self.model_preparation_time = round(time.time() - start_time, 4)

@@ -5132,10 +5131,6 @@ class Trainer:
            except StopIteration:
                break
-       # Keep default behavior the same
-       if not self.model_accepts_loss_kwargs:
-           return batch_samples, None
        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
            # For now we don't support object detection
            try:

@@ -5145,4 +5140,8 @@ class Trainer:
        if self.args.average_tokens_across_devices:
            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
+       if torch.is_tensor(num_items_in_batch):
+           num_items_in_batch = num_items_in_batch.item()
        return batch_samples, num_items_in_batch
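The Trainer hunks revolve around `num_items_in_batch`: `compute_loss` now always receives it, and the helper that returns `batch_samples, num_items_in_batch` converts it from a tensor to a plain Python number. The point, which the tests below verify, is that gradient accumulation should reproduce the full-batch loss. A toy sketch of the two normalizations, under the assumption that the loss is a sum of per-token losses:

```python
import torch

token_losses = torch.arange(1.0, 9.0)    # per-token losses of one "full" batch (8 tokens)
micro_batches = token_losses.chunk(2)    # the same batch split for gradient accumulation

# Naive approach: average the per-micro-batch means. This only matches the full-batch
# mean when every micro-batch holds the same number of real (unpadded) tokens.
naive = sum(mb.mean() for mb in micro_batches) / len(micro_batches)

# num_items_in_batch-style: sum each micro-batch and divide by the total token count
# of the accumulated batch, which matches the full-batch mean exactly.
num_items_in_batch = token_losses.numel()
fixed = sum(mb.sum() for mb in micro_batches) / num_items_in_batch

assert torch.isclose(fixed, token_losses.mean())
```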
@@ -750,11 +750,102 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
        self.check_trained_model(trainer.model, alternate_seed=True)

    @slow
-    def test_gradient_accumulation_loss_alignment(self):
+    def test_gradient_accumulation_loss_alignment_with_model_loss(self):
        set_seed(42)
        import datasets

-        model_name = "distilgpt2"
+        model_name = "nickypro/tinyllama-110M"
        dataset_name = "wikitext"
        dataset_config = "wikitext-2-raw-v1"
        dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
+        dataset = dataset.train_test_split(test_size=0.2)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        tokenizer.pad_token = tokenizer.eos_token
+
+        def tokenize_function(examples):
+            return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+
+        base_loss_callback = StoreLossCallback()
+
+        args_kwargs = {
+            "report_to": "none",
+            "logging_steps": 1,
+            "max_steps": 20,
+            "learning_rate": 3e-4,
+            "disable_tqdm": True,
+        }
+
+        args = TrainingArguments(
+            "./generation",
+            **args_kwargs,
+        )
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[base_loss_callback],
+            data_collator=data_collator,
+        )
+        assert trainer.model_accepts_loss_kwargs
+        trainer.train()
+
+        grad_accum_loss_callback = StoreLossCallback()
+        args = TrainingArguments(
+            "./generation",
+            **args_kwargs,
+            gradient_accumulation_steps=2,
+            per_device_train_batch_size=4,
+        )
+        set_seed(42)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[grad_accum_loss_callback],
+            data_collator=data_collator,
+        )
+        trainer.train()
+
+        set_seed(42)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        broken_loss_callback = StoreLossCallback()
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[broken_loss_callback],
+            data_collator=data_collator,
+        )
+        # disable model_accepts_loss_kwargs
+        trainer.model_accepts_loss_kwargs = False
+        trainer.train()
+
+        # Calculate the difference between the base loss and the grad_accum loss
+        diff_truth = [
+            abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
+        ]
+        diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+
+        # all diff truth should be quite close
+        self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
+
+        # max diff broken should be very off
+        self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
+
+    @slow
+    def test_gradient_accumulation_loss_alignment_with_loss_func(self):
+        set_seed(42)
+        import datasets
+
+        model_name = "roneneldan/TinyStories-33M"
+        dataset_name = "wikitext"
+        dataset_config = "wikitext-2-raw-v1"
+        dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
@@ -836,15 +927,16 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
        trainer.train()

        # Calculate the difference between the base loss and the grad_accum loss
-        diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
-        diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
-        # These should be quite close
-        for diff in diff_truth:
-            self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
+        diff_truth = [
+            abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
+        ]
+        diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]

-        # These should be very off
-        for diff in diff_broken:
-            self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
+        # all diff truth should be quite close
+        self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
+
+        # max diff broken should be very off
+        self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")

    def test_gradient_accumulation(self):
        # Training with half the batch size but accumulation steps as 2 should give the same training losses.
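The tests rely on a `StoreLossCallback` that is not shown in this diff; presumably it records the logged loss at every step so the three runs can be compared. A hypothetical minimal version is sketched below; the class behaviour (collecting only `logs["loss"]`) is an assumption, not the repository's actual helper:

```python
from transformers import TrainerCallback


class StoreLossCallbackSketch(TrainerCallback):
    """Collect the 'loss' value from every logging event (assumed behaviour)."""

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])
```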