Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 09:03:53 +08:00
[Trainer] [Breaking change] use_cache default to False (#41585)

* use_cache default to `False` when training
* style
* Fix comment
* add checks
* style
* set
* switch
@@ -738,6 +738,10 @@ class Trainer:
         self._train_batch_size = args.train_batch_size
         self._created_lr_scheduler = False

+        # Set use_cache for the model
+        if getattr(self.model, "config", None) is not None:
+            self.model.config.use_cache = self.args.use_cache
+
         # very last
         self._memory_tracker.stop_and_update_metrics()
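The practical effect of this hunk is that `Trainer` now copies `args.use_cache` onto `model.config` at init time, so training runs with the KV cache off unless the user opts back in. A minimal sketch of the new behavior (assuming a transformers version that includes this change; `output_dir` is just a placeholder path):

from transformers import TrainingArguments

# With this change, use_cache defaults to False, and Trainer propagates
# it to model.config.use_cache whenever the model has a config.
args = TrainingArguments(output_dir="out")
print(args.use_cache)  # False

# Setups that rely on `past_key_values` (e.g. some PEFT methods) opt back in:
args = TrainingArguments(output_dir="out", use_cache=True)
print(args.use_cache)  # True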
@@ -752,6 +752,10 @@ class TrainingArguments:
             Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize
             num_tokens_in_batch for precise loss calculation. Reference:
             https://github.com/huggingface/transformers/issues/34242
+        use_cache (`bool`, *optional*, defaults to `False`):
+            Whether or not to enable cache for the model. For training, this is usually not needed apart from
+            some PEFT methods that use `past_key_values`.
+
     """

     # Sometimes users will pass in a `str` repr of a dict in the CLI
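For context on what the flag controls: with `use_cache` enabled, a forward pass returns `past_key_values`, which generation (and the PEFT methods mentioned above) can reuse; with it disabled, nothing is cached. A rough illustration, assuming any small causal LM checkpoint is reachable (the model name here is only a placeholder):

import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any causal LM would behave the same way.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
input_ids = torch.tensor([[1, 2, 3]])

out = model(input_ids, use_cache=True)
print(out.past_key_values is not None)  # True: KV states returned for reuse

out = model(input_ids, use_cache=False)
print(out.past_key_values)  # None: no cache kept, the typical training case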
@@ -1382,6 +1386,13 @@ class TrainingArguments:
         },
     )

+    use_cache: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to use cache for the model. For training, this is usually not needed apart from some PEFT methods that use `past_key_values`."
+        },
+    )
+
     def __post_init__(self):
         # Set default output_dir if not provided
         if self.output_dir is None:
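Because `use_cache` is now a regular dataclass field on `TrainingArguments`, it is also settable from the command line via `HfArgumentParser`. A sketch of that path (the flag spelling follows the field name; the printed values assume the `default=False` above):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)

# Default: the cache is off for training.
(args,) = parser.parse_args_into_dataclasses(args=["--output_dir", "out"])
print(args.use_cache)  # False

# Opt back in, e.g. for PEFT methods that need `past_key_values`.
(args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--use_cache", "true"]
)
print(args.use_cache)  # True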