mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[mergebot] Create land time check options (#77943)
This adds land time checks before we try to merge. What this does is: 1. Merge changes into latest master, check out a new branch, push, and have a workflow that runs jobs from trunk (and maybe pull) 2. Wait for all checks in the landtime workflow to finish by using the GH API (graphql doesn't have this method from what I can see) 3. push the changes in Test Plan: Tested this in canary with a new workflow that passes and lint, tested what happens if i break the new workflow by exiting with 1, the normal flow, and some other flows. Tested it breaking when land checks fail: https://github.com/pytorch/pytorch-canary/pull/113#issuecomment-1165941716 Test that it works: https://github.com/pytorch/pytorch-canary/pull/114#issuecomment-1165922791 Test that normal validations like PR is broken: https://github.com/pytorch/pytorch-canary/pull/113#issuecomment-1165930037 Test that normal merge works: https://github.com/pytorch/pytorch-canary/pull/113#issuecomment-1166751288 Test that force merge works: https://github.com/pytorch/pytorch-canary/pull/113#issuecomment-1167507356 Pull Request resolved: https://github.com/pytorch/pytorch/pull/77943 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
79c2dfcd8e
commit
2acd2317b8
123
.github/scripts/trymerge.py
vendored
123
.github/scripts/trymerge.py
vendored
@ -10,7 +10,7 @@ from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from urllib.request import urlopen, Request
|
||||
from urllib.error import HTTPError
|
||||
from typing import cast, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Iterable, cast, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, patterns_to_regex, GitRepo
|
||||
from functools import lru_cache
|
||||
from warnings import warn
|
||||
@ -312,7 +312,6 @@ RE_REVERT_CMD = re.compile(r"@pytorch(merge|)bot\s+revert\s+this")
|
||||
RE_REVERT_CMD_CLI = re.compile(r"@pytorch(merge|)bot\s+revert\s+(-m.*-c.*|-c.*-m.*)")
|
||||
RE_DIFF_REV = re.compile(r'^Differential Revision:.+?(D[0-9]+)', re.MULTILINE)
|
||||
|
||||
|
||||
def _fetch_url(url: str, *,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
@ -332,7 +331,6 @@ def _fetch_url(url: str, *,
|
||||
print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}")
|
||||
raise
|
||||
|
||||
|
||||
def fetch_json(url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
@ -341,6 +339,13 @@ def fetch_json(url: str,
|
||||
url += '?' + '&'.join(f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items())
|
||||
return cast(List[Dict[str, Any]], _fetch_url(url, headers=headers, data=data, reader=json.load))
|
||||
|
||||
def fetch_json_dict(url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None) -> Dict[str, Any] :
|
||||
headers = {'Accept': 'application/vnd.github.v3+json'}
|
||||
if params is not None and len(params) > 0:
|
||||
url += '?' + '&'.join(f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items())
|
||||
return cast(Dict[str, Any], _fetch_url(url, headers=headers, data=data, reader=json.load))
|
||||
|
||||
def _gh_post_comment(url: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
|
||||
if dry_run:
|
||||
@ -395,6 +400,7 @@ def parse_args() -> Any:
|
||||
parser.add_argument("--dry-run", action="store_true")
|
||||
parser.add_argument("--on-green", action="store_true")
|
||||
parser.add_argument("--on-mandatory", action="store_true")
|
||||
parser.add_argument("--land-checks", action="store_true")
|
||||
parser.add_argument("--revert", action="store_true")
|
||||
parser.add_argument("--force", action="store_true")
|
||||
parser.add_argument("--comment-id", type=int)
|
||||
@ -733,11 +739,28 @@ class GitHubPR:
|
||||
msg += f"Approved by: {approved_by_urls}\n"
|
||||
return msg
|
||||
|
||||
def merge_into(self, repo: GitRepo, *, force: bool = False, dry_run: bool = False, comment_id: Optional[int] = None) -> None:
|
||||
def merge_into(self, repo: GitRepo, *,
|
||||
force: bool = False,
|
||||
dry_run: bool = False,
|
||||
comment_id: Optional[int] = None) -> None:
|
||||
# Raises exception if matching rule is not found
|
||||
find_matching_merge_rule(self, repo, force=force, skip_internal_checks=can_skip_internal_checks(self, comment_id))
|
||||
if repo.current_branch() != self.default_branch():
|
||||
repo.checkout(self.default_branch())
|
||||
self.merge_changes(repo, force, comment_id)
|
||||
|
||||
repo.push(self.default_branch(), dry_run)
|
||||
gh_post_pr_comment(self.org, self.project, self.pr_num,
|
||||
f"@{self.get_pr_creator_login()} your PR has been successfully merged.", dry_run)
|
||||
if not dry_run:
|
||||
gh_add_labels(self.org, self.project, self.pr_num, ["merged"])
|
||||
|
||||
def merge_changes(self,
|
||||
repo: GitRepo,
|
||||
force: bool = False,
|
||||
comment_id: Optional[int] = None,
|
||||
branch: Optional[str] = None) -> None:
|
||||
branch_to_merge_into = self.default_branch() if branch is None else branch
|
||||
if repo.current_branch() != branch_to_merge_into:
|
||||
repo.checkout(branch_to_merge_into)
|
||||
if not self.is_ghstack_pr():
|
||||
msg = self.gen_commit_message()
|
||||
pr_branch_name = f"__pull-request-{self.pr_num}__init__"
|
||||
@ -747,9 +770,24 @@ class GitHubPR:
|
||||
else:
|
||||
self.merge_ghstack_into(repo, force, comment_id=comment_id)
|
||||
|
||||
repo.push(self.default_branch(), dry_run)
|
||||
if not dry_run:
|
||||
gh_add_labels(self.org, self.project, self.pr_num, ["merged"])
|
||||
def create_land_time_check_branch(self,
|
||||
repo: GitRepo,
|
||||
branch: str,
|
||||
force: bool = False,
|
||||
comment_id: Optional[int] = None,) -> str:
|
||||
self.merge_changes(repo, branch=branch, force=force, comment_id=comment_id)
|
||||
land_check_branch = f'landchecks/{self.pr_num}'
|
||||
try:
|
||||
repo._run_git('branch', "-D", land_check_branch)
|
||||
except Exception:
|
||||
pass
|
||||
repo._run_git('checkout', "-b", land_check_branch)
|
||||
repo._run_git('push', '-u', 'origin', land_check_branch, '--force')
|
||||
commit = repo.get_commit('HEAD').commit_hash
|
||||
gh_post_pr_comment(self.org, self.project, self.pr_num,
|
||||
'Successfully started land time checks.' +
|
||||
f' See progress here: https://hud.pytorch.org/{self.org}/{self.project}/commit/{commit}')
|
||||
return commit
|
||||
|
||||
|
||||
class MandatoryChecksMissingError(Exception):
|
||||
@ -838,21 +876,10 @@ def find_matching_merge_rule(pr: GitHubPR,
|
||||
reject_reason = (f"Matched rule {rule_name}, but PR #{pr.pr_num} was not reviewed yet by any of: " +
|
||||
f"{', '.join(list(rule_approvers_set)[:5])}{', ...' if len(rule_approvers_set) > 5 else ''}")
|
||||
continue
|
||||
if rule.mandatory_checks_name is not None:
|
||||
pending_checks: List[Tuple[str, Optional[str]]] = []
|
||||
failed_checks: List[Tuple[str, Optional[str]]] = []
|
||||
checks = pr.get_checkrun_conclusions()
|
||||
# HACK: We don't want to skip CLA check, even when forced
|
||||
for checkname in filter(lambda x: force is False or "CLA Check" in x, rule.mandatory_checks_name):
|
||||
if checkname not in checks:
|
||||
pending_checks.append((checkname, None))
|
||||
elif checks[checkname][0] is None:
|
||||
pending_checks.append((checkname, checks[checkname][1]))
|
||||
elif checks[checkname][0] != 'SUCCESS':
|
||||
failed_checks.append((checkname, checks[checkname][1]))
|
||||
|
||||
def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
|
||||
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
|
||||
mandatory_checks = rule.mandatory_checks_name if rule.mandatory_checks_name is not None else []
|
||||
checks = pr.get_checkrun_conclusions()
|
||||
required_checks = filter(lambda x: force is False or "CLA Check" in x, mandatory_checks)
|
||||
[pending_checks, failed_checks] = categorize_checks(checks, required_checks)
|
||||
|
||||
if len(failed_checks) > 0:
|
||||
if reject_reason_score < 30000:
|
||||
@ -874,6 +901,9 @@ def find_matching_merge_rule(pr: GitHubPR,
|
||||
raise RuntimeError(reject_reason)
|
||||
|
||||
|
||||
def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
|
||||
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
|
||||
|
||||
def pr_get_checks_with_lambda(pr: GitHubPR, status_check: Callable[[Optional[str]], bool]) -> List[Tuple[str, str]]:
|
||||
checks = pr.get_checkrun_conclusions()
|
||||
return [(name, status[1]) for name, status in checks.items() if status_check(status[0])]
|
||||
@ -938,7 +968,6 @@ def try_revert(repo: GitRepo, pr: GitHubPR, *,
|
||||
def prefix_with_github_url(suffix_str: str) -> str:
|
||||
return f"https://github.com/{suffix_str}"
|
||||
|
||||
|
||||
def check_for_sev(org: str, project: str, force: bool) -> None:
|
||||
if force:
|
||||
return
|
||||
@ -959,6 +988,38 @@ def check_for_sev(org: str, project: str, force: bool) -> None:
|
||||
)
|
||||
return
|
||||
|
||||
def fetch_check_run_conclusions(repo: GitRepo, commit: str) -> Dict[str, Tuple[str, str]]:
|
||||
[owner, name] = repo.gh_owner_and_name()
|
||||
checks = fetch_json_dict(f'https://api.github.com/repos/{owner}/{name}/commits/{commit}/check-runs')
|
||||
check_run_conclusions = {}
|
||||
if len(checks) == 0:
|
||||
raise MandatoryChecksMissingError("Refusing to merge as land check(s) are not yet run")
|
||||
for check_run in checks['check_runs']:
|
||||
check_run_conclusions[check_run['name']] = (check_run['conclusion'],
|
||||
check_run['html_url'])
|
||||
return check_run_conclusions
|
||||
|
||||
def validate_land_time_checks(repo: GitRepo, commit: str) -> None:
|
||||
checks = fetch_check_run_conclusions(repo, commit)
|
||||
[pending_checks, failed_checks] = categorize_checks(checks, checks)
|
||||
|
||||
if len(failed_checks) > 0:
|
||||
raise RuntimeError(f"Failed to merge; some land checks failed: {checks_to_str(failed_checks)}")
|
||||
if len(pending_checks) > 0:
|
||||
raise MandatoryChecksMissingError(f"Refusing to merge as land check(s) {checks_to_str(pending_checks)} are not yet run")
|
||||
|
||||
def categorize_checks(check_runs: Dict[str, Tuple[str, str]],
|
||||
required_checks: Iterable[str]) -> Tuple[List[Tuple[str, Optional[str]]], List[Tuple[str, Optional[str]]]]:
|
||||
pending_checks: List[Tuple[str, Optional[str]]] = []
|
||||
failed_checks: List[Tuple[str, Optional[str]]] = []
|
||||
for checkname in required_checks:
|
||||
if checkname not in check_runs:
|
||||
pending_checks.append((checkname, None))
|
||||
elif check_runs[checkname][0] is None:
|
||||
pending_checks.append((checkname, check_runs[checkname][1]))
|
||||
elif check_runs[checkname][0].upper() != 'SUCCESS' and check_runs[checkname][0].upper() != 'SKIPPED':
|
||||
failed_checks.append((checkname, check_runs[checkname][1]))
|
||||
return (pending_checks, failed_checks)
|
||||
|
||||
def merge(pr_num: int, repo: GitRepo,
|
||||
dry_run: bool = False,
|
||||
@ -966,6 +1027,7 @@ def merge(pr_num: int, repo: GitRepo,
|
||||
comment_id: Optional[int] = None,
|
||||
mandatory_only: bool = False,
|
||||
on_green: bool = False,
|
||||
land_checks: bool = False,
|
||||
timeout_minutes: int = 400,
|
||||
stale_pr_days: int = 3) -> None:
|
||||
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
|
||||
@ -979,6 +1041,9 @@ def merge(pr_num: int, repo: GitRepo,
|
||||
if (datetime.utcnow() - pr.last_pushed_at()).days > stale_pr_days:
|
||||
raise RuntimeError("This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again.")
|
||||
|
||||
if land_checks:
|
||||
commit = pr.create_land_time_check_branch(repo, 'viable/strict', force=force, comment_id=comment_id)
|
||||
|
||||
start_time = time.time()
|
||||
last_exception = ''
|
||||
elapsed_time = 0.0
|
||||
@ -994,12 +1059,16 @@ def merge(pr_num: int, repo: GitRepo,
|
||||
find_matching_merge_rule(pr, repo)
|
||||
pending = pr_get_pending_checks(pr)
|
||||
failing = pr_get_failed_checks(pr)
|
||||
|
||||
if (not mandatory_only and on_green) and len(failing) > 0:
|
||||
raise RuntimeError(f"{len(failing)} additional jobs have failed, first few of them are: " +
|
||||
' ,'.join(f"[{x[0]}]({x[1]})" for x in failing[:5]))
|
||||
if (not mandatory_only and on_green) and len(pending) > 0:
|
||||
raise MandatoryChecksMissingError(f"Still waiting for {len(pending)} additional jobs to finish, " +
|
||||
f"first few of them are: {' ,'.join(x[0] for x in pending[:5])}")
|
||||
if land_checks:
|
||||
validate_land_time_checks(repo, commit)
|
||||
|
||||
return pr.merge_into(repo, dry_run=dry_run, force=force, comment_id=comment_id)
|
||||
except MandatoryChecksMissingError as ex:
|
||||
last_exception = str(ex)
|
||||
@ -1052,11 +1121,11 @@ def main() -> None:
|
||||
force=args.force,
|
||||
comment_id=args.comment_id,
|
||||
on_green=args.on_green,
|
||||
mandatory_only=args.on_mandatory)
|
||||
mandatory_only=args.on_mandatory,
|
||||
land_checks=args.land_checks)
|
||||
except Exception as e:
|
||||
handle_exception(e)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user