Compare commits

...

1 Commits

Author SHA1 Message Date
d7a0632979 make it optional 2025-02-26 16:45:52 +01:00

View File

@@ -362,7 +362,7 @@ class Trainer:
be able to choose different architectures according to hyper parameters (such as layer count, sizes of
inner layers, dropout probabilities etc).
compute_loss_func (`Callable`, *optional*):
- A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
+ A function that accepts the raw model outputs, labels, and optionally the number of items in the entire accumulated
batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`].
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
@@ -3778,7 +3778,10 @@ class Trainer:
model_name = unwrapped_model._get_name()
# User-defined compute_loss function
if self.compute_loss_func is not None:
- loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
+ compute_loss_kwargs = {}
+ if "num_items_in_batch" in inspect.signature(self.compute_loss_func).parameters.keys():
+     compute_loss_kwargs["num_items_in_batch"] = num_items_in_batch
+ loss = self.compute_loss_func(outputs, labels, **compute_loss_kwargs)
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else: