[GHF] Refactors (#116446)

Prep change for allowing stacked reverts

This is a no-op that factors out some helper function that would be
useful later:
 - `get_pr_commit_sha` finds a committed sha for a given PR
 - `_revlist_to_prs` converts a revlist to GitHubPRs conditionally
   filtering some out
 - `do_revert_prs` reverts multiple PRs in a batch, but so far is
   invoked with only one PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116446
Approved by: https://github.com/huydhn, https://github.com/seemethere
This commit is contained in:
Nikita Shulga
2023-12-27 10:30:26 -08:00
committed by PyTorch MergeBot
parent 85628c0e57
commit 5fcc2519f5
2 changed files with 128 additions and 60 deletions

View File

@ -20,7 +20,18 @@ from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional, Pattern, Tuple
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Pattern,
Tuple,
)
from warnings import warn
import yaml
@ -612,19 +623,14 @@ def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -
return comment.author_login == "facebook-github-bot"
def get_ghstack_prs(
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
def _revlist_to_prs(
repo: GitRepo,
pr: "GitHubPR",
rev_list: Iterable[str],
should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
) -> List[Tuple["GitHubPR", str]]:
"""
Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
@:param open_only: Only return open PRs
"""
assert pr.is_ghstack_pr()
entire_stack: List[Tuple[GitHubPR, str]] = []
# For ghstack, cherry-pick commits based from origin
orig_ref = f"{repo.remote}/{re.sub(r'/head$', '/orig', pr.head_ref())}"
rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
for idx, rev in enumerate(reversed(rev_list)):
rc: List[Tuple[GitHubPR, str]] = []
for idx, rev in enumerate(rev_list):
msg = repo.commit_message(rev)
m = RE_PULL_REQUEST_RESOLVED.search(msg)
if m is None:
@ -635,17 +641,35 @@ def get_ghstack_prs(
raise RuntimeError(
f"PR {m.group('number')} resolved to wrong owner/repo pair"
)
stacked_pr_num = int(m.group("number"))
if stacked_pr_num != pr.pr_num:
stacked_pr = GitHubPR(pr.org, pr.project, stacked_pr_num)
if open_only and stacked_pr.is_closed():
print(
f"Skipping {idx+1} of {len(rev_list)} PR (#{stacked_pr_num}) as its already been merged"
)
continue
entire_stack.append((stacked_pr, rev))
else:
entire_stack.append((pr, rev))
pr_num = int(m.group("number"))
candidate = GitHubPR(pr.org, pr.project, pr_num) if pr_num != pr.pr_num else pr
if should_skip is not None and should_skip(idx, candidate):
continue
rc.append((candidate, rev))
return rc
def get_ghstack_prs(
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
) -> List[Tuple["GitHubPR", str]]:
"""
Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
@:param open_only: Only return open PRs
"""
# For ghstack, cherry-pick commits based from origin
orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
def skip_func(idx: int, candidate: "GitHubPR") -> bool:
if not open_only or not candidate.is_closed():
return False
print(
f"Skipping {idx+1} of {len(rev_list)} PR (#{candidate.pr_num}) as its already been merged"
)
return True
assert pr.is_ghstack_pr()
entire_stack = _revlist_to_prs(repo, pr, reversed(rev_list), skip_func)
for stacked_pr, rev in entire_stack:
if stacked_pr.is_closed():
@ -694,6 +718,10 @@ class GitHubPR:
def is_ghstack_pr(self) -> bool:
return RE_GHSTACK_HEAD_REF.match(self.head_ref()) is not None
def get_ghstack_orig_ref(self) -> str:
assert self.is_ghstack_pr()
return re.sub(r"/head$", "/orig", self.head_ref())
def is_base_repo_private(self) -> bool:
return bool(self.info["baseRepository"]["isPrivate"])
@ -1732,6 +1760,16 @@ def filter_checks_with_lambda(
return [check for check in checks.values() if status_filter(check.status)]
def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
commit_sha = pr.get_merge_commit()
if commit_sha is not None:
return commit_sha
commits = repo.commits_resolving_gh_pr(pr.pr_num)
if len(commits) == 0:
raise PostCommentError("Can't find any commits resolving PR")
return commits[0]
def validate_revert(
repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
) -> Tuple[str, str]:
@ -1758,15 +1796,54 @@ def validate_revert(
find_matching_merge_rule(
pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
)
commit_sha = pr.get_merge_commit()
if commit_sha is None:
commits = repo.commits_resolving_gh_pr(pr.pr_num)
if len(commits) == 0:
raise PostCommentError("Can't find any commits resolving PR")
commit_sha = commits[0]
commit_sha = get_pr_commit_sha(repo, pr)
return (author_login, commit_sha)
def do_revert_prs(
repo: GitRepo,
shas_and_prs: List[Tuple[str, GitHubPR]],
*,
author_login: str,
extra_msg: str = "",
skip_internal_checks: bool = False,
dry_run: bool = False,
) -> None:
# Prepare and push revert commits
commit_shas: List[str] = []
for commit_sha, pr in shas_and_prs:
revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
revert_msg += extra_msg
repo.checkout(pr.default_branch())
repo.revert(commit_sha)
msg = repo.commit_message("HEAD")
msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg)
msg += revert_msg
repo.amend_commit_message(msg)
repo.push(shas_and_prs[0][1].default_branch(), dry_run)
# Comment/reopen PRs
for commit_sha, pr in shas_and_prs:
revert_message = (
f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
)
if (
pr.has_internal_changes()
and not pr.has_no_connected_diff()
and not skip_internal_checks
):
revert_message += "\n:warning: This PR might contain internal changes"
revert_message += "\ncc: @pytorch/pytorch-dev-infra"
gh_post_pr_comment(
pr.org, pr.project, pr.pr_num, revert_message, dry_run=dry_run
)
if not dry_run:
pr.add_numbered_label("reverted")
gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg)
gh_update_pr_state(pr.org, pr.project, pr.pr_num)
def try_revert(
repo: GitRepo,
pr: GitHubPR,
@ -1775,45 +1852,26 @@ def try_revert(
comment_id: Optional[int] = None,
reason: Optional[str] = None,
) -> None:
def post_comment(msg: str) -> None:
gh_post_pr_comment(pr.org, pr.project, pr.pr_num, msg, dry_run=dry_run)
try:
author_login, commit_sha = validate_revert(repo, pr, comment_id=comment_id)
except PostCommentError as e:
return post_comment(str(e))
gh_post_pr_comment(pr.org, pr.project, pr.pr_num, str(e), dry_run=dry_run)
return
revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
revert_msg += f" due to {reason}" if reason is not None else ""
revert_msg += (
extra_msg = f" due to {reason}" if reason is not None else ""
extra_msg += (
f" ([comment]({pr.get_comment_by_id(comment_id).url}))\n"
if comment_id is not None
else "\n"
)
repo.checkout(pr.default_branch())
repo.revert(commit_sha)
msg = repo.commit_message("HEAD")
msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg)
msg += revert_msg
repo.amend_commit_message(msg)
repo.push(pr.default_branch(), dry_run)
revert_message = (
f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
do_revert_prs(
repo,
[(commit_sha, pr)],
author_login=author_login,
extra_msg=extra_msg,
dry_run=dry_run,
skip_internal_checks=can_skip_internal_checks(pr, comment_id),
)
if (
pr.has_internal_changes()
and not pr.has_no_connected_diff()
and not can_skip_internal_checks(pr, comment_id)
):
revert_message += "\n:warning: This PR might contain internal changes"
revert_message += "\ncc: @pytorch/pytorch-dev-infra"
post_comment(revert_message)
if not dry_run:
pr.add_numbered_label("reverted")
gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg)
gh_update_pr_state(pr.org, pr.project, pr.pr_num)
def prefix_with_github_url(suffix_str: str) -> str: