Compare commits

...

3 Commits

View File

@ -29,7 +29,12 @@ logger = logging.getLogger(__name__)
@add_start_docstrings(TrainingArguments.__doc__)
class Seq2SeqTrainingArguments(TrainingArguments):
"""
Args:
sortish_sampler (`bool`, *optional*, defaults to `False`):
Whether to use a *sortish sampler* or not. For now this is only possible if the underlying datasets are
*Seq2SeqDataset*, but it will become generally available in the near future.
It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness
for the training set.
predict_with_generate (`bool`, *optional*, defaults to `False`):
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
generation_max_length (`int`, *optional*):
@ -46,7 +51,7 @@ class Seq2SeqTrainingArguments(TrainingArguments):
- a path to a *directory* containing a configuration file saved using the
[`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
- a [`~generation.GenerationConfig`] object.
"""
""" # fmt: skip # Prevent Ruff from altering the indentation
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
predict_with_generate: bool = field(