Add support for specifying revisions when pushing to Hub via internal Trainer call (#36852)

* Update training_args.py

* Update trainer.py

* fixes

* fix

* remove extraneous comments

* explicit revision arg

* add msg

* fixup

* fix field name

* rename field revision to hub_revision

* restore gradient_checkpointing doc

* fix ws

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Isaac Breen
2025-06-19 10:35:33 +08:00
committed by GitHub
parent 458e0b376c
commit 3756bf192c
2 changed files with 19 additions and 1 deletion

View File

@ -3938,7 +3938,7 @@ class Trainer:
# Push to the Hub when `save_model` is called by the user.
if self.args.push_to_hub and not _internal_call:
self.push_to_hub(commit_message="Model save")
self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision)
def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
@ -4788,6 +4788,7 @@ class Trainer:
token=self.args.hub_token,
run_as_future=True,
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
revision=self.args.hub_revision,
)
push_jobs = [model_push_job]
@ -4803,6 +4804,7 @@ class Trainer:
commit_message=commit_message + ", checkpoint",
token=self.args.hub_token,
run_as_future=True,
revision=self.args.hub_revision,
)
push_jobs.append(checkpoint_push)
@ -4882,8 +4884,12 @@ class Trainer:
self.create_model_card(model_name=model_name, **kwargs)
if revision is None:
revision = self.args.hub_revision
# Wait for the current upload to be finished.
self._finish_current_push()
return upload_folder(
repo_id=self.hub_model_id,
folder_path=self.args.output_dir,

View File

@ -693,6 +693,8 @@ class TrainingArguments:
Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
hub_always_push (`bool`, *optional*, defaults to `False`):
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
hub_revision (`str`, *optional*):
The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
@ -1361,6 +1363,12 @@ class TrainingArguments:
default=False,
metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."},
)
hub_revision: Optional[str] = field(
default=None,
metadata={
"help": "The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash."
},
)
gradient_checkpointing: bool = field(
default=False,
metadata={
@ -2861,6 +2869,7 @@ class TrainingArguments:
token: Optional[str] = None,
private_repo: Optional[bool] = None,
always_push: bool = False,
revision: Optional[str] = None,
):
"""
A method that regroups all arguments linked to synchronizing checkpoints with the Hub.
@ -2904,6 +2913,8 @@ class TrainingArguments:
always_push (`bool`, *optional*, defaults to `False`):
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not
finished.
revision (`str`, *optional*):
The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
Example:
@ -2922,6 +2933,7 @@ class TrainingArguments:
self.hub_token = token
self.hub_private_repo = private_repo
self.hub_always_push = always_push
self.hub_revision = revision
return self
def set_optimizer(