mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[GHF] Refactors (#116446)
Prep change for allowing stacked reverts This is a no-op that factors out some helper function that would be useful later: - `get_pr_commit_sha` finds a committed sha for a given PR - `_revlist_to_prs` converts a revlist to GitHubPRs conditionally filtering some out - `do_revert_prs` reverts multiple PRs in a batch, but so far is invoked with only one PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/116446 Approved by: https://github.com/huydhn, https://github.com/seemethere
This commit is contained in:
committed by
PyTorch MergeBot
parent
85628c0e57
commit
5fcc2519f5
178
.github/scripts/trymerge.py
vendored
178
.github/scripts/trymerge.py
vendored
@ -20,7 +20,18 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional, Pattern, Tuple
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Pattern,
|
||||
Tuple,
|
||||
)
|
||||
from warnings import warn
|
||||
|
||||
import yaml
|
||||
@ -612,19 +623,14 @@ def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -
|
||||
return comment.author_login == "facebook-github-bot"
|
||||
|
||||
|
||||
def get_ghstack_prs(
|
||||
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
|
||||
def _revlist_to_prs(
|
||||
repo: GitRepo,
|
||||
pr: "GitHubPR",
|
||||
rev_list: Iterable[str],
|
||||
should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
|
||||
) -> List[Tuple["GitHubPR", str]]:
|
||||
"""
|
||||
Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
|
||||
@:param open_only: Only return open PRs
|
||||
"""
|
||||
assert pr.is_ghstack_pr()
|
||||
entire_stack: List[Tuple[GitHubPR, str]] = []
|
||||
# For ghstack, cherry-pick commits based from origin
|
||||
orig_ref = f"{repo.remote}/{re.sub(r'/head$', '/orig', pr.head_ref())}"
|
||||
rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
|
||||
for idx, rev in enumerate(reversed(rev_list)):
|
||||
rc: List[Tuple[GitHubPR, str]] = []
|
||||
for idx, rev in enumerate(rev_list):
|
||||
msg = repo.commit_message(rev)
|
||||
m = RE_PULL_REQUEST_RESOLVED.search(msg)
|
||||
if m is None:
|
||||
@ -635,17 +641,35 @@ def get_ghstack_prs(
|
||||
raise RuntimeError(
|
||||
f"PR {m.group('number')} resolved to wrong owner/repo pair"
|
||||
)
|
||||
stacked_pr_num = int(m.group("number"))
|
||||
if stacked_pr_num != pr.pr_num:
|
||||
stacked_pr = GitHubPR(pr.org, pr.project, stacked_pr_num)
|
||||
if open_only and stacked_pr.is_closed():
|
||||
print(
|
||||
f"Skipping {idx+1} of {len(rev_list)} PR (#{stacked_pr_num}) as its already been merged"
|
||||
)
|
||||
continue
|
||||
entire_stack.append((stacked_pr, rev))
|
||||
else:
|
||||
entire_stack.append((pr, rev))
|
||||
pr_num = int(m.group("number"))
|
||||
candidate = GitHubPR(pr.org, pr.project, pr_num) if pr_num != pr.pr_num else pr
|
||||
if should_skip is not None and should_skip(idx, candidate):
|
||||
continue
|
||||
rc.append((candidate, rev))
|
||||
return rc
|
||||
|
||||
|
||||
def get_ghstack_prs(
|
||||
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
|
||||
) -> List[Tuple["GitHubPR", str]]:
|
||||
"""
|
||||
Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
|
||||
@:param open_only: Only return open PRs
|
||||
"""
|
||||
# For ghstack, cherry-pick commits based from origin
|
||||
orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
|
||||
rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
|
||||
|
||||
def skip_func(idx: int, candidate: "GitHubPR") -> bool:
|
||||
if not open_only or not candidate.is_closed():
|
||||
return False
|
||||
print(
|
||||
f"Skipping {idx+1} of {len(rev_list)} PR (#{candidate.pr_num}) as its already been merged"
|
||||
)
|
||||
return True
|
||||
|
||||
assert pr.is_ghstack_pr()
|
||||
entire_stack = _revlist_to_prs(repo, pr, reversed(rev_list), skip_func)
|
||||
|
||||
for stacked_pr, rev in entire_stack:
|
||||
if stacked_pr.is_closed():
|
||||
@ -694,6 +718,10 @@ class GitHubPR:
|
||||
def is_ghstack_pr(self) -> bool:
|
||||
return RE_GHSTACK_HEAD_REF.match(self.head_ref()) is not None
|
||||
|
||||
def get_ghstack_orig_ref(self) -> str:
|
||||
assert self.is_ghstack_pr()
|
||||
return re.sub(r"/head$", "/orig", self.head_ref())
|
||||
|
||||
def is_base_repo_private(self) -> bool:
|
||||
return bool(self.info["baseRepository"]["isPrivate"])
|
||||
|
||||
@ -1732,6 +1760,16 @@ def filter_checks_with_lambda(
|
||||
return [check for check in checks.values() if status_filter(check.status)]
|
||||
|
||||
|
||||
def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
|
||||
commit_sha = pr.get_merge_commit()
|
||||
if commit_sha is not None:
|
||||
return commit_sha
|
||||
commits = repo.commits_resolving_gh_pr(pr.pr_num)
|
||||
if len(commits) == 0:
|
||||
raise PostCommentError("Can't find any commits resolving PR")
|
||||
return commits[0]
|
||||
|
||||
|
||||
def validate_revert(
|
||||
repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
|
||||
) -> Tuple[str, str]:
|
||||
@ -1758,15 +1796,54 @@ def validate_revert(
|
||||
find_matching_merge_rule(
|
||||
pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
|
||||
)
|
||||
commit_sha = pr.get_merge_commit()
|
||||
if commit_sha is None:
|
||||
commits = repo.commits_resolving_gh_pr(pr.pr_num)
|
||||
if len(commits) == 0:
|
||||
raise PostCommentError("Can't find any commits resolving PR")
|
||||
commit_sha = commits[0]
|
||||
commit_sha = get_pr_commit_sha(repo, pr)
|
||||
return (author_login, commit_sha)
|
||||
|
||||
|
||||
def do_revert_prs(
|
||||
repo: GitRepo,
|
||||
shas_and_prs: List[Tuple[str, GitHubPR]],
|
||||
*,
|
||||
author_login: str,
|
||||
extra_msg: str = "",
|
||||
skip_internal_checks: bool = False,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
# Prepare and push revert commits
|
||||
commit_shas: List[str] = []
|
||||
for commit_sha, pr in shas_and_prs:
|
||||
revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
|
||||
revert_msg += extra_msg
|
||||
repo.checkout(pr.default_branch())
|
||||
repo.revert(commit_sha)
|
||||
msg = repo.commit_message("HEAD")
|
||||
msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg)
|
||||
msg += revert_msg
|
||||
repo.amend_commit_message(msg)
|
||||
repo.push(shas_and_prs[0][1].default_branch(), dry_run)
|
||||
|
||||
# Comment/reopen PRs
|
||||
for commit_sha, pr in shas_and_prs:
|
||||
revert_message = (
|
||||
f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
|
||||
)
|
||||
if (
|
||||
pr.has_internal_changes()
|
||||
and not pr.has_no_connected_diff()
|
||||
and not skip_internal_checks
|
||||
):
|
||||
revert_message += "\n:warning: This PR might contain internal changes"
|
||||
revert_message += "\ncc: @pytorch/pytorch-dev-infra"
|
||||
gh_post_pr_comment(
|
||||
pr.org, pr.project, pr.pr_num, revert_message, dry_run=dry_run
|
||||
)
|
||||
|
||||
if not dry_run:
|
||||
pr.add_numbered_label("reverted")
|
||||
gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg)
|
||||
gh_update_pr_state(pr.org, pr.project, pr.pr_num)
|
||||
|
||||
|
||||
def try_revert(
|
||||
repo: GitRepo,
|
||||
pr: GitHubPR,
|
||||
@ -1775,45 +1852,26 @@ def try_revert(
|
||||
comment_id: Optional[int] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
def post_comment(msg: str) -> None:
|
||||
gh_post_pr_comment(pr.org, pr.project, pr.pr_num, msg, dry_run=dry_run)
|
||||
|
||||
try:
|
||||
author_login, commit_sha = validate_revert(repo, pr, comment_id=comment_id)
|
||||
except PostCommentError as e:
|
||||
return post_comment(str(e))
|
||||
gh_post_pr_comment(pr.org, pr.project, pr.pr_num, str(e), dry_run=dry_run)
|
||||
return
|
||||
|
||||
revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
|
||||
revert_msg += f" due to {reason}" if reason is not None else ""
|
||||
revert_msg += (
|
||||
extra_msg = f" due to {reason}" if reason is not None else ""
|
||||
extra_msg += (
|
||||
f" ([comment]({pr.get_comment_by_id(comment_id).url}))\n"
|
||||
if comment_id is not None
|
||||
else "\n"
|
||||
)
|
||||
repo.checkout(pr.default_branch())
|
||||
repo.revert(commit_sha)
|
||||
msg = repo.commit_message("HEAD")
|
||||
msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg)
|
||||
msg += revert_msg
|
||||
repo.amend_commit_message(msg)
|
||||
repo.push(pr.default_branch(), dry_run)
|
||||
|
||||
revert_message = (
|
||||
f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
|
||||
do_revert_prs(
|
||||
repo,
|
||||
[(commit_sha, pr)],
|
||||
author_login=author_login,
|
||||
extra_msg=extra_msg,
|
||||
dry_run=dry_run,
|
||||
skip_internal_checks=can_skip_internal_checks(pr, comment_id),
|
||||
)
|
||||
if (
|
||||
pr.has_internal_changes()
|
||||
and not pr.has_no_connected_diff()
|
||||
and not can_skip_internal_checks(pr, comment_id)
|
||||
):
|
||||
revert_message += "\n:warning: This PR might contain internal changes"
|
||||
revert_message += "\ncc: @pytorch/pytorch-dev-infra"
|
||||
post_comment(revert_message)
|
||||
|
||||
if not dry_run:
|
||||
pr.add_numbered_label("reverted")
|
||||
gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg)
|
||||
gh_update_pr_state(pr.org, pr.project, pr.pr_num)
|
||||
|
||||
|
||||
def prefix_with_github_url(suffix_str: str) -> str:
|
||||
|
Reference in New Issue
Block a user