Compare commits

...

8 Commits

SHA1 Message Date
241c04d368 [Whisper] patch float type on mps (#35295)
* fix float type on mps

* make
2024-12-16 17:35:57 +01:00
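
A minimal sketch of the dtype guard this patch applies, assuming only that the mps backend lacks float64 kernels (the helper name is illustrative, not the actual Whisper internals):

```python
import torch

# mps has no float64 kernels, so timestamp math falls back to float32 there.
def timestamp_dtype(device: torch.device) -> torch.dtype:
    return torch.float32 if device.type == "mps" else torch.float64

seek = torch.zeros(4, dtype=torch.long)  # stand-in for Whisper's seek positions
time_precision, input_stride = 0.02, 2
time_offset = seek.to(timestamp_dtype(seek.device)) * time_precision / input_stride
```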
1b6cb1eefc don't use no_sync when deepspeed doesn't support it for certain zero stages (#35157)
* don't use no_sync when deepspeed doesn't support it for certain zero stages

* chore: lint

* fix no_sync context for deepspeed across all zero types

* chore: lint
2024-12-13 19:23:54 +01:00
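
A rough sketch of the resulting context selection, assuming an `accelerate` Accelerator (simplified from the actual Trainer code):

```python
import contextlib
import functools

from accelerate.utils import DistributedType

def sync_context(accelerator, model, i, num_batches):
    # Defer gradient sync between accumulation steps, except under DeepSpeed,
    # where some ZeRO stages cannot skip synchronization.
    if i != num_batches - 1 and accelerator.distributed_type != DistributedType.DEEPSPEED:
        return functools.partial(accelerator.no_sync, model=model)
    return contextlib.nullcontext

# usage inside the accumulation loop: with sync_context(accelerator, model, i, n)():
```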
b1d5d6dd65 Fix FSDP no longer working (#35212)
Fix FSDP failing
2024-12-13 19:23:40 +01:00
bf5d7c3fa3 Only import torch.distributed if it is available (#35133) 2024-12-10 18:39:05 +01:00
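
A minimal sketch of the guard, assuming torch >= 2.5 for the tensor-parallel import:

```python
import torch

# Only touch torch.distributed when the local build actually ships it
# (some PyTorch wheels are compiled without distributed support).
_torch_distributed_available = torch.distributed.is_available()

if _torch_distributed_available:
    from torch.distributed.tensor import Replicate  # noqa: F401
```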
49952300bb v4.47.1 2024-12-10 08:41:52 +01:00
d5ccfcc39a Fix num_items_in_batch not being an integer (#35115)
In the method `Trainer#get_batch_samples`, the return values should be a
list of batch samples and an integer indicating the number of items in
the batch. However, this was not actually the case: instead of an
integer, a tensor with one element was returned. In a multi-GPU setup,
this tensor can be placed on a different device than the loss tensor,
causing the loss function to raise a `RuntimeError`.

The problem arises from
5d7739f15a/src/transformers/trainer.py (L5139-L5144),
where the outer `sum` operates over a list of tensors, which means the
final result is also a tensor. To counter this issue, a new check
(after the accelerator gathering) has been added to convert a
potential tensor to an integer before returning
`num_items_in_batch`.
2024-12-10 08:41:19 +01:00
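
A minimal sketch of the described conversion (stand-in values; the real check sits at the end of `Trainer#get_batch_samples`):

```python
import torch

# Summing a list of per-batch token counts yields a tensor, which after
# gathering may live on a specific device; .item() makes it a plain int.
num_items_in_batch = sum([torch.tensor(3), torch.tensor(5)])
if torch.is_tensor(num_items_in_batch):
    num_items_in_batch = num_items_in_batch.item()
assert isinstance(num_items_in_batch, int)
```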
0485b6e881 Fix GA loss bugs and add unit test (#35121)
* fix GA bugs and add unit test

* narrow down model loss unit test diff gap

* format code to make ruff happy

* send num_items_in_batch argument to decoder

* fix GA loss bug in BertLMHeadModel

* use TinyStories-33M to narrow down diff gap

* format code

* missing .config

* avoid add extra args

---------

Co-authored-by: kangsheng <kangsheng@meituan.com>
2024-12-09 11:19:20 +01:00
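
A toy illustration of the gradient-accumulation loss bug addressed here, with made-up numbers: averaging each micro-batch loss independently weights tokens unevenly when micro-batches contain different numbers of label tokens, while normalizing by a shared `num_items_in_batch` recovers the full-batch loss.

```python
import torch

token_loss_sums = torch.tensor([12.0, 4.0])  # summed token losses per micro-batch
token_counts = torch.tensor([6.0, 2.0])      # label tokens per micro-batch

per_micro_mean = (token_loss_sums / token_counts).mean()      # 1.5, biased
full_batch_mean = token_loss_sums.sum() / token_counts.sum()  # 2.0, correct
```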
5d7739f15a Release: v4.47.0 2024-12-05 18:17:54 +01:00
59 changed files with 187 additions and 86 deletions

View File

@@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset

View File

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

View File

@@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

View File

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = logging.getLogger(__name__)

View File

@@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logging.basicConfig(level=logging.INFO)
 logger = get_logger(__name__)

View File

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

View File

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)

View File

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = logging.getLogger(__name__)

View File

@@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 logger = logging.getLogger(__name__)

View File

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
 # region Checking dependencies
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 task_to_keys = {
     "cola": ("sentence", None),

View File

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 # region Dependencies and constants
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
+check_min_version("4.47.0")
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@@ -435,7 +435,7 @@ install_requires = [
 setup(
     name="transformers",
-    version="4.47.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.47.1",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="transformers@huggingface.co",
     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.47.0.dev0"
+__version__ = "4.47.1"
 
 from typing import TYPE_CHECKING

View File

@@ -1325,6 +1325,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **loss_kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):

@@ -1375,11 +1376,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
         lm_loss = None
         if labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]
View File

@@ -491,6 +491,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
         kwargs_decoder = {
             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
         }
+        if "num_items_in_batch" in kwargs_encoder:
+            kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
 
         if encoder_outputs is None:
             if inputs is None:

View File

@@ -632,7 +632,9 @@ class WhisperGenerationMixin(GenerationMixin):
                 cur_bsz=cur_bsz,
                 batch_idx_map=batch_idx_map,
             )
-            time_offset = seek.to(torch.float64) * time_precision / input_stride
+            time_offset = (
+                seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
+            )
             seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
 
             # 6.2 cut out next 30s segment from input features

@@ -1805,6 +1807,7 @@ class WhisperGenerationMixin(GenerationMixin):
         timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
         timestamp_segment_indices.add_(1)
         token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
+        device = seek_sequence.device
 
         # If whisper predicted a "end of segment" via a timestep token, let's go ever each
         # "end of segment" prediction and slice the decoding into segments accordingly

@@ -1828,8 +1831,12 @@ class WhisperGenerationMixin(GenerationMixin):
                 end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
                 segments.append(
                     {
-                        "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
-                        "end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
+                        "start": time_offset[prev_idx]
+                        + start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
+                        * time_precision,
+                        "end": time_offset[prev_idx]
+                        + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
+                        * time_precision,
                         "tokens": sliced_tokens,
                         "result": seek_outputs[idx],
                     }

@@ -1856,7 +1863,9 @@ class WhisperGenerationMixin(GenerationMixin):
             last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
             if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
                 # no consecutive timestamps but it has a timestamp; use the last one.
-                last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
+                last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
+                    torch.float32 if device.type == "mps" else torch.float64
+                )
             segments = [
                 {
                     "start": time_offset[prev_idx],

View File

@@ -38,8 +38,10 @@ is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse(
 is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
 
+# Cache this result has it's a C FFI call which can be pretty time-consuming
+_torch_distributed_available = torch.distributed.is_available()
+
-if is_torch_greater_or_equal("2.5"):
+if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
     from torch.distributed.tensor import Replicate
     from torch.distributed.tensor.parallel import (
         ColwiseParallel,

View File

@@ -2251,7 +2251,7 @@ class Trainer:
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:

@@ -2304,12 +2304,13 @@ class Trainer:
             # In case of auto_find_batch_size=True
             # Remove FSDP wrapping from sub-models.
             self.model = unwrap_model(self.model, recursive=True)
 
-        # configure fsdp plugin for qlora if any
-        self._fsdp_qlora_plugin_updates()
-
         if delay_optimizer_creation:
             if use_accelerator_prepare:
-                self.model = self.accelerator.prepare(self.model)
+                # configure fsdp plugin for qlora if any
+                self._fsdp_qlora_plugin_updates()
+                if self.accelerator.mixed_precision != "fp8":
+                    self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
         # prepare using `accelerator` prepare

@@ -2516,6 +2517,7 @@ class Trainer:
                 context = (
                     functools.partial(self.accelerator.no_sync, model=model)
                     if i != len(batch_samples) - 1
+                    and self.accelerator.distributed_type != DistributedType.DEEPSPEED
                     else contextlib.nullcontext
                 )
                 with context():

@@ -3649,10 +3651,7 @@ class Trainer:
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
         with self.compute_loss_context_manager():
-            if self.model_accepts_loss_kwargs:
-                loss = self.compute_loss(model, inputs)
-            else:
-                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
 
         del inputs
         if (

@@ -4175,7 +4174,7 @@ class Trainer:
         start_time = time.time()
         model = (
             self.accelerator.prepare(model)
-            if self.is_deepspeed_enabled or self.is_fsdp_enabled
+            if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
             else self.accelerator.prepare_model(model, evaluation_mode=True)
         )
         self.model_preparation_time = round(time.time() - start_time, 4)

@@ -5132,10 +5131,6 @@ class Trainer:
             except StopIteration:
                 break
 
-        # Keep default behavior the same
-        if not self.model_accepts_loss_kwargs:
-            return batch_samples, None
-
         if len(batch_samples) > 0 and "labels" in batch_samples[0]:
             # For now we don't support object detection
             try:

@@ -5145,4 +5140,8 @@ class Trainer:
             if self.args.average_tokens_across_devices:
                 num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
 
+        if torch.is_tensor(num_items_in_batch):
+            num_items_in_batch = num_items_in_batch.item()
+
         return batch_samples, num_items_in_batch

View File

@@ -750,11 +750,102 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.check_trained_model(trainer.model, alternate_seed=True)
 
     @slow
-    def test_gradient_accumulation_loss_alignment(self):
+    def test_gradient_accumulation_loss_alignment_with_model_loss(self):
         set_seed(42)
         import datasets
 
-        model_name = "distilgpt2"
+        model_name = "nickypro/tinyllama-110M"
         dataset_name = "wikitext"
         dataset_config = "wikitext-2-raw-v1"
         dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
+        dataset = dataset.train_test_split(test_size=0.2)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        tokenizer.pad_token = tokenizer.eos_token
+
+        def tokenize_function(examples):
+            return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+
+        base_loss_callback = StoreLossCallback()
+
+        args_kwargs = {
+            "report_to": "none",
+            "logging_steps": 1,
+            "max_steps": 20,
+            "learning_rate": 3e-4,
+            "disable_tqdm": True,
+        }
+
+        args = TrainingArguments(
+            "./generation",
+            **args_kwargs,
+        )
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[base_loss_callback],
+            data_collator=data_collator,
+        )
+        assert trainer.model_accepts_loss_kwargs
+        trainer.train()
+
+        grad_accum_loss_callback = StoreLossCallback()
+        args = TrainingArguments(
+            "./generation",
+            **args_kwargs,
+            gradient_accumulation_steps=2,
+            per_device_train_batch_size=4,
+        )
+        set_seed(42)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[grad_accum_loss_callback],
+            data_collator=data_collator,
+        )
+        trainer.train()
+
+        set_seed(42)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        broken_loss_callback = StoreLossCallback()
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[broken_loss_callback],
+            data_collator=data_collator,
+        )
+        # disable model_accepts_loss_kwargs
+        trainer.model_accepts_loss_kwargs = False
+        trainer.train()
+
+        # Calculate the difference between the base loss and the grad_accum loss
+        diff_truth = [
+            abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
+        ]
+        diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+
+        # all diff truth should be quite close
+        self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
+
+        # max diff broken should be very off
+        self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
+
+    @slow
+    def test_gradient_accumulation_loss_alignment_with_loss_func(self):
+        set_seed(42)
+        import datasets
+
+        model_name = "roneneldan/TinyStories-33M"
+        dataset_name = "wikitext"
+        dataset_config = "wikitext-2-raw-v1"
+        dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")

@@ -836,15 +927,16 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         trainer.train()
 
         # Calculate the difference between the base loss and the grad_accum loss
-        diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
-        diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
-
-        # These should be quite close
-        for diff in diff_truth:
-            self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
+        diff_truth = [
+            abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
+        ]
+        diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
 
-        # These should be very off
-        for diff in diff_broken:
-            self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
+        # all diff truth should be quite close
+        self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
+
+        # max diff broken should be very off
+        self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
 
     def test_gradient_accumulation(self):
         # Training with half the batch size but accumulation steps as 2 should give the same training losses.