[BE] Refactor trymerge for readability (#161637)

Two changes:
- Extract getting the last_commit's sha into it's own function
- Rename merge_changes to merge_changes_locally to better explain it's functionality
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161637
Approved by: https://github.com/seemethere, https://github.com/malfet
ghstack dependencies: #161558
This commit is contained in:
Zain Rizvi
2025-08-27 16:46:27 -05:00
committed by PyTorch MergeBot
parent ee0ec21191
commit 6b051d7de3

View File

@ -737,16 +737,24 @@ class GitHubPR:
def last_commit(self) -> Any:
return self.info["commits"]["nodes"][-1]["commit"]
def last_commit_sha(self, default: Optional[str] = None) -> str:
# for commits, the oid is the sha
if default is None:
return str(self.last_commit()["oid"])
return str(self.last_commit().get("oid", default))
def get_merge_base(self) -> str:
if self.merge_base:
return self.merge_base
last_commit_oid = self.last_commit()["oid"]
last_commit_sha = self.last_commit_sha()
# NB: We could use self.base_ref() here for regular PR, however, that doesn't
# work for ghstack where the base is the custom branch, i.e. gh/USER/ID/base,
# so let's just use main instead
self.merge_base = gh_fetch_merge_base(
self.org, self.project, last_commit_oid, self.default_branch()
self.org, self.project, last_commit_sha, self.default_branch()
)
# Fallback to baseRefOid if the API call fails, i.e. rate limit. Note that baseRefOid
@ -1167,7 +1175,7 @@ class GitHubPR:
skip_internal_checks=can_skip_internal_checks(self, comment_id),
ignore_current_checks=ignore_current_checks,
)
additional_merged_prs = self.merge_changes(
additional_merged_prs = self.merge_changes_locally(
repo, skip_mandatory_checks, comment_id
)
@ -1196,7 +1204,7 @@ class GitHubPR:
broken_trunk_checks=ignorable_checks.get("BROKEN_TRUNK", []),
flaky_checks=ignorable_checks.get("FLAKY", []),
unstable_checks=ignorable_checks.get("UNSTABLE", []),
last_commit_sha=self.last_commit().get("oid", ""),
last_commit_sha=self.last_commit_sha(default=""),
merge_base_sha=self.get_merge_base(),
merge_commit_sha=merge_commit_sha,
is_failed=False,
@ -1217,7 +1225,7 @@ class GitHubPR:
dry_run=dry_run,
)
def merge_changes(
def merge_changes_locally(
self,
repo: GitRepo,
skip_mandatory_checks: bool = False,
@ -1241,14 +1249,14 @@ class GitHubPR:
msg = self.gen_commit_message()
pr_branch_name = f"__pull-request-{self.pr_num}__init__"
repo.fetch(self.last_commit()["oid"], pr_branch_name)
repo.fetch(self.last_commit_sha(), pr_branch_name)
repo._run_git("merge", "--squash", pr_branch_name)
repo._run_git("commit", f'--author="{self.get_author()}"', "-m", msg)
# Did the PR change since we started the merge?
pulled_sha = repo.show_ref(pr_branch_name)
latest_pr_status = GitHubPR(self.org, self.project, self.pr_num)
if pulled_sha != latest_pr_status.last_commit()["oid"]:
if pulled_sha != latest_pr_status.last_commit_sha():
raise RuntimeError(
"PR has been updated since CI checks last passed. Please rerun the merge command."
)
@ -1458,7 +1466,7 @@ def find_matching_merge_rule(
pending_checks = []
failed_checks = []
hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit()['oid']}"
hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit_sha()}"
if len(failed_checks) > 0:
if reject_reason_score < 30000:
reject_reason_score = 30000
@ -2163,7 +2171,7 @@ def merge(
stale_pr_days: int = 3,
ignore_current: bool = False,
) -> None:
initial_commit_sha = pr.last_commit()["oid"]
initial_commit_sha = pr.last_commit_sha()
pr_link = f"https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num}"
print(f"Attempting merge of {initial_commit_sha} ({pr_link})")
@ -2234,7 +2242,7 @@ def merge(
f"Attempting merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} ({elapsed_time / 60} minutes elapsed)"
)
pr = GitHubPR(pr.org, pr.project, pr.pr_num)
if initial_commit_sha != pr.last_commit()["oid"]:
if initial_commit_sha != pr.last_commit_sha():
raise RuntimeError(
"New commits were pushed while merging. Please rerun the merge command."
)
@ -2401,7 +2409,7 @@ def main() -> None:
if args.check_mergeability:
if pr.is_ghstack_pr():
get_ghstack_prs(repo, pr) # raises error if out of sync
pr.merge_changes(
pr.merge_changes_locally(
repo,
skip_mandatory_checks=True,
skip_all_rule_checks=True,
@ -2449,7 +2457,7 @@ def main() -> None:
broken_trunk_checks=[],
flaky_checks=[],
unstable_checks=[],
last_commit_sha=pr.last_commit().get("oid", ""),
last_commit_sha=pr.last_commit_sha(default=""),
merge_base_sha=pr.get_merge_base(),
is_failed=True,
skip_mandatory_checks=args.force,