mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Refactor trymerge for readability (#161637)
Two changes: - Extract getting the last_commit's sha into it's own function - Rename merge_changes to merge_changes_locally to better explain it's functionality Pull Request resolved: https://github.com/pytorch/pytorch/pull/161637 Approved by: https://github.com/seemethere, https://github.com/malfet ghstack dependencies: #161558
This commit is contained in:
committed by
PyTorch MergeBot
parent
ee0ec21191
commit
6b051d7de3
32
.github/scripts/trymerge.py
vendored
32
.github/scripts/trymerge.py
vendored
@ -737,16 +737,24 @@ class GitHubPR:
|
||||
def last_commit(self) -> Any:
|
||||
return self.info["commits"]["nodes"][-1]["commit"]
|
||||
|
||||
def last_commit_sha(self, default: Optional[str] = None) -> str:
|
||||
# for commits, the oid is the sha
|
||||
|
||||
if default is None:
|
||||
return str(self.last_commit()["oid"])
|
||||
|
||||
return str(self.last_commit().get("oid", default))
|
||||
|
||||
def get_merge_base(self) -> str:
|
||||
if self.merge_base:
|
||||
return self.merge_base
|
||||
|
||||
last_commit_oid = self.last_commit()["oid"]
|
||||
last_commit_sha = self.last_commit_sha()
|
||||
# NB: We could use self.base_ref() here for regular PR, however, that doesn't
|
||||
# work for ghstack where the base is the custom branch, i.e. gh/USER/ID/base,
|
||||
# so let's just use main instead
|
||||
self.merge_base = gh_fetch_merge_base(
|
||||
self.org, self.project, last_commit_oid, self.default_branch()
|
||||
self.org, self.project, last_commit_sha, self.default_branch()
|
||||
)
|
||||
|
||||
# Fallback to baseRefOid if the API call fails, i.e. rate limit. Note that baseRefOid
|
||||
@ -1167,7 +1175,7 @@ class GitHubPR:
|
||||
skip_internal_checks=can_skip_internal_checks(self, comment_id),
|
||||
ignore_current_checks=ignore_current_checks,
|
||||
)
|
||||
additional_merged_prs = self.merge_changes(
|
||||
additional_merged_prs = self.merge_changes_locally(
|
||||
repo, skip_mandatory_checks, comment_id
|
||||
)
|
||||
|
||||
@ -1196,7 +1204,7 @@ class GitHubPR:
|
||||
broken_trunk_checks=ignorable_checks.get("BROKEN_TRUNK", []),
|
||||
flaky_checks=ignorable_checks.get("FLAKY", []),
|
||||
unstable_checks=ignorable_checks.get("UNSTABLE", []),
|
||||
last_commit_sha=self.last_commit().get("oid", ""),
|
||||
last_commit_sha=self.last_commit_sha(default=""),
|
||||
merge_base_sha=self.get_merge_base(),
|
||||
merge_commit_sha=merge_commit_sha,
|
||||
is_failed=False,
|
||||
@ -1217,7 +1225,7 @@ class GitHubPR:
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
def merge_changes(
|
||||
def merge_changes_locally(
|
||||
self,
|
||||
repo: GitRepo,
|
||||
skip_mandatory_checks: bool = False,
|
||||
@ -1241,14 +1249,14 @@ class GitHubPR:
|
||||
|
||||
msg = self.gen_commit_message()
|
||||
pr_branch_name = f"__pull-request-{self.pr_num}__init__"
|
||||
repo.fetch(self.last_commit()["oid"], pr_branch_name)
|
||||
repo.fetch(self.last_commit_sha(), pr_branch_name)
|
||||
repo._run_git("merge", "--squash", pr_branch_name)
|
||||
repo._run_git("commit", f'--author="{self.get_author()}"', "-m", msg)
|
||||
|
||||
# Did the PR change since we started the merge?
|
||||
pulled_sha = repo.show_ref(pr_branch_name)
|
||||
latest_pr_status = GitHubPR(self.org, self.project, self.pr_num)
|
||||
if pulled_sha != latest_pr_status.last_commit()["oid"]:
|
||||
if pulled_sha != latest_pr_status.last_commit_sha():
|
||||
raise RuntimeError(
|
||||
"PR has been updated since CI checks last passed. Please rerun the merge command."
|
||||
)
|
||||
@ -1458,7 +1466,7 @@ def find_matching_merge_rule(
|
||||
pending_checks = []
|
||||
failed_checks = []
|
||||
|
||||
hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit()['oid']}"
|
||||
hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit_sha()}"
|
||||
if len(failed_checks) > 0:
|
||||
if reject_reason_score < 30000:
|
||||
reject_reason_score = 30000
|
||||
@ -2163,7 +2171,7 @@ def merge(
|
||||
stale_pr_days: int = 3,
|
||||
ignore_current: bool = False,
|
||||
) -> None:
|
||||
initial_commit_sha = pr.last_commit()["oid"]
|
||||
initial_commit_sha = pr.last_commit_sha()
|
||||
pr_link = f"https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num}"
|
||||
print(f"Attempting merge of {initial_commit_sha} ({pr_link})")
|
||||
|
||||
@ -2234,7 +2242,7 @@ def merge(
|
||||
f"Attempting merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} ({elapsed_time / 60} minutes elapsed)"
|
||||
)
|
||||
pr = GitHubPR(pr.org, pr.project, pr.pr_num)
|
||||
if initial_commit_sha != pr.last_commit()["oid"]:
|
||||
if initial_commit_sha != pr.last_commit_sha():
|
||||
raise RuntimeError(
|
||||
"New commits were pushed while merging. Please rerun the merge command."
|
||||
)
|
||||
@ -2401,7 +2409,7 @@ def main() -> None:
|
||||
if args.check_mergeability:
|
||||
if pr.is_ghstack_pr():
|
||||
get_ghstack_prs(repo, pr) # raises error if out of sync
|
||||
pr.merge_changes(
|
||||
pr.merge_changes_locally(
|
||||
repo,
|
||||
skip_mandatory_checks=True,
|
||||
skip_all_rule_checks=True,
|
||||
@ -2449,7 +2457,7 @@ def main() -> None:
|
||||
broken_trunk_checks=[],
|
||||
flaky_checks=[],
|
||||
unstable_checks=[],
|
||||
last_commit_sha=pr.last_commit().get("oid", ""),
|
||||
last_commit_sha=pr.last_commit_sha(default=""),
|
||||
merge_base_sha=pr.get_merge_base(),
|
||||
is_failed=True,
|
||||
skip_mandatory_checks=args.force,
|
||||
|
Reference in New Issue
Block a user