Add support for specifying revisions when pushing to Hub via internal Trainer call (#36852)
* Update training_args.py * Update trainer.py * fixes * fix * remove extraneous comments * explicit revision arg * add msg * fixup * fix field name * rename field revision to hub_revision * restore gradient_checkpointing doc * fix ws --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@ -3938,7 +3938,7 @@ class Trainer:
|
||||
|
||||
# Push to the Hub when `save_model` is called by the user.
|
||||
if self.args.push_to_hub and not _internal_call:
|
||||
self.push_to_hub(commit_message="Model save")
|
||||
self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision)
|
||||
|
||||
def _save_tpu(self, output_dir: Optional[str] = None):
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
@ -4788,6 +4788,7 @@ class Trainer:
|
||||
token=self.args.hub_token,
|
||||
run_as_future=True,
|
||||
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
|
||||
revision=self.args.hub_revision,
|
||||
)
|
||||
|
||||
push_jobs = [model_push_job]
|
||||
@ -4803,6 +4804,7 @@ class Trainer:
|
||||
commit_message=commit_message + ", checkpoint",
|
||||
token=self.args.hub_token,
|
||||
run_as_future=True,
|
||||
revision=self.args.hub_revision,
|
||||
)
|
||||
push_jobs.append(checkpoint_push)
|
||||
|
||||
@ -4882,8 +4884,12 @@ class Trainer:
|
||||
|
||||
self.create_model_card(model_name=model_name, **kwargs)
|
||||
|
||||
if revision is None:
|
||||
revision = self.args.hub_revision
|
||||
|
||||
# Wait for the current upload to be finished.
|
||||
self._finish_current_push()
|
||||
|
||||
return upload_folder(
|
||||
repo_id=self.hub_model_id,
|
||||
folder_path=self.args.output_dir,
|
||||
|
||||
@ -693,6 +693,8 @@ class TrainingArguments:
|
||||
Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
|
||||
hub_always_push (`bool`, *optional*, defaults to `False`):
|
||||
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
|
||||
hub_revision (`str`, *optional*):
|
||||
The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
|
||||
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
|
||||
@ -1361,6 +1363,12 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."},
|
||||
)
|
||||
hub_revision: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash."
|
||||
},
|
||||
)
|
||||
gradient_checkpointing: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@ -2861,6 +2869,7 @@ class TrainingArguments:
|
||||
token: Optional[str] = None,
|
||||
private_repo: Optional[bool] = None,
|
||||
always_push: bool = False,
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
A method that regroups all arguments linked to synchronizing checkpoints with the Hub.
|
||||
@ -2904,6 +2913,8 @@ class TrainingArguments:
|
||||
always_push (`bool`, *optional*, defaults to `False`):
|
||||
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not
|
||||
finished.
|
||||
revision (`str`, *optional*):
|
||||
The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
|
||||
|
||||
Example:
|
||||
|
||||
@ -2922,6 +2933,7 @@ class TrainingArguments:
|
||||
self.hub_token = token
|
||||
self.hub_private_repo = private_repo
|
||||
self.hub_always_push = always_push
|
||||
self.hub_revision = revision
|
||||
return self
|
||||
|
||||
def set_optimizer(
|
||||
|
||||
Reference in New Issue
Block a user