[mergebot] Create land time check options (#77943)

This adds land time checks before we try to merge. What this does is:
1. Merge the PR's changes onto the latest master, check out a new branch, push it, and have a workflow run jobs from trunk (and maybe pull) on that branch
2. Wait for all checks in the land-time workflow to finish using the GitHub REST API (GraphQL doesn't expose this, from what I can see) — see the sketch after this list
3. Push the changes in
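
A minimal sketch of the polling in step 2 (assumes a `GITHUB_TOKEN` environment variable and an illustrative repo/commit; the actual logic lives in `fetch_check_run_conclusions` and `validate_land_time_checks` in the diff below):

```python
import json
import os
import time
from urllib.request import urlopen, Request

def land_checks_finished(owner: str, repo: str, commit: str) -> bool:
    # GitHub REST API: list check runs for a commit (the part GraphQL doesn't expose)
    url = f"https://api.github.com/repos/{owner}/{repo}/commits/{commit}/check-runs"
    req = Request(url, headers={
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"token {os.environ['GITHUB_TOKEN']}",  # assumed token env var
    })
    with urlopen(req) as resp:
        data = json.load(resp)
    runs = data.get("check_runs", [])
    if not runs:
        # No check runs yet means the land-check workflow hasn't started
        return False
    # A run is still pending while its conclusion is None
    return all(run["conclusion"] is not None for run in runs)

# Poll until every check run on the land-check branch's commit has concluded
while not land_checks_finished("pytorch", "pytorch-canary", "abc123"):  # commit hash is illustrative
    time.sleep(60)
```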

Test Plan:
Tested this in canary with a new workflow that passes plus lint, tested what happens if I break the new workflow by exiting with 1, the normal flow, and some other flows.

Tested that the merge is rejected when land checks fail:
https://github.com/pytorch/pytorch-canary/pull/113#issuecomment-1165941716

Tested that it works:
https://github.com/pytorch/pytorch-canary/pull/114#issuecomment-1165922791

Tested that normal validations (e.g., the PR is broken) still apply:
https://github.com/pytorch/pytorch-canary/pull/113#issuecomment-1165930037

Tested that a normal merge works:
https://github.com/pytorch/pytorch-canary/pull/113#issuecomment-1166751288

Tested that a force merge works:
https://github.com/pytorch/pytorch-canary/pull/113#issuecomment-1167507356
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77943
Approved by: https://github.com/janeyx99
This commit is contained in:
zengk95
2022-06-27 22:28:12 +00:00
committed by PyTorch MergeBot
parent 79c2dfcd8e
commit 2acd2317b8
6 changed files with 106 additions and 27 deletions


@@ -10,7 +10,7 @@ from datetime import datetime
from dataclasses import dataclass
from urllib.request import urlopen, Request
from urllib.error import HTTPError
from typing import cast, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Iterable, cast, Any, Callable, Dict, List, Optional, Tuple, Union
from gitutils import get_git_remote_name, get_git_repo_dir, patterns_to_regex, GitRepo
from functools import lru_cache
from warnings import warn
@@ -312,7 +312,6 @@ RE_REVERT_CMD = re.compile(r"@pytorch(merge|)bot\s+revert\s+this")
RE_REVERT_CMD_CLI = re.compile(r"@pytorch(merge|)bot\s+revert\s+(-m.*-c.*|-c.*-m.*)")
RE_DIFF_REV = re.compile(r'^Differential Revision:.+?(D[0-9]+)', re.MULTILINE)
def _fetch_url(url: str, *,
headers: Optional[Dict[str, str]] = None,
data: Optional[Dict[str, Any]] = None,
@@ -332,7 +331,6 @@ def _fetch_url(url: str, *,
print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}")
raise
def fetch_json(url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
@@ -341,6 +339,13 @@ def fetch_json(url: str,
url += '?' + '&'.join(f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items())
return cast(List[Dict[str, Any]], _fetch_url(url, headers=headers, data=data, reader=json.load))
def fetch_json_dict(url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None) -> Dict[str, Any] :
headers = {'Accept': 'application/vnd.github.v3+json'}
if params is not None and len(params) > 0:
url += '?' + '&'.join(f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items())
return cast(Dict[str, Any], _fetch_url(url, headers=headers, data=data, reader=json.load))
def _gh_post_comment(url: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
if dry_run:
@@ -395,6 +400,7 @@ def parse_args() -> Any:
parser.add_argument("--dry-run", action="store_true")
parser.add_argument("--on-green", action="store_true")
parser.add_argument("--on-mandatory", action="store_true")
parser.add_argument("--land-checks", action="store_true")
parser.add_argument("--revert", action="store_true")
parser.add_argument("--force", action="store_true")
parser.add_argument("--comment-id", type=int)
@@ -733,11 +739,28 @@ class GitHubPR:
msg += f"Approved by: {approved_by_urls}\n"
return msg
def merge_into(self, repo: GitRepo, *, force: bool = False, dry_run: bool = False, comment_id: Optional[int] = None) -> None:
def merge_into(self, repo: GitRepo, *,
force: bool = False,
dry_run: bool = False,
comment_id: Optional[int] = None) -> None:
# Raises exception if matching rule is not found
find_matching_merge_rule(self, repo, force=force, skip_internal_checks=can_skip_internal_checks(self, comment_id))
if repo.current_branch() != self.default_branch():
repo.checkout(self.default_branch())
self.merge_changes(repo, force, comment_id)
repo.push(self.default_branch(), dry_run)
gh_post_pr_comment(self.org, self.project, self.pr_num,
f"@{self.get_pr_creator_login()} your PR has been successfully merged.", dry_run)
if not dry_run:
gh_add_labels(self.org, self.project, self.pr_num, ["merged"])
def merge_changes(self,
repo: GitRepo,
force: bool = False,
comment_id: Optional[int] = None,
branch: Optional[str] = None) -> None:
branch_to_merge_into = self.default_branch() if branch is None else branch
if repo.current_branch() != branch_to_merge_into:
repo.checkout(branch_to_merge_into)
if not self.is_ghstack_pr():
msg = self.gen_commit_message()
pr_branch_name = f"__pull-request-{self.pr_num}__init__"
@@ -747,9 +770,24 @@ class GitHubPR:
else:
self.merge_ghstack_into(repo, force, comment_id=comment_id)
repo.push(self.default_branch(), dry_run)
if not dry_run:
gh_add_labels(self.org, self.project, self.pr_num, ["merged"])
def create_land_time_check_branch(self,
repo: GitRepo,
branch: str,
force: bool = False,
comment_id: Optional[int] = None,) -> str:
self.merge_changes(repo, branch=branch, force=force, comment_id=comment_id)
land_check_branch = f'landchecks/{self.pr_num}'
try:
repo._run_git('branch', "-D", land_check_branch)
except Exception:
pass
repo._run_git('checkout', "-b", land_check_branch)
repo._run_git('push', '-u', 'origin', land_check_branch, '--force')
commit = repo.get_commit('HEAD').commit_hash
gh_post_pr_comment(self.org, self.project, self.pr_num,
'Successfully started land time checks.' +
f' See progress here: https://hud.pytorch.org/{self.org}/{self.project}/commit/{commit}')
return commit
class MandatoryChecksMissingError(Exception):
@@ -838,21 +876,10 @@ def find_matching_merge_rule(pr: GitHubPR,
reject_reason = (f"Matched rule {rule_name}, but PR #{pr.pr_num} was not reviewed yet by any of: " +
f"{', '.join(list(rule_approvers_set)[:5])}{', ...' if len(rule_approvers_set) > 5 else ''}")
continue
if rule.mandatory_checks_name is not None:
pending_checks: List[Tuple[str, Optional[str]]] = []
failed_checks: List[Tuple[str, Optional[str]]] = []
checks = pr.get_checkrun_conclusions()
# HACK: We don't want to skip CLA check, even when forced
for checkname in filter(lambda x: force is False or "CLA Check" in x, rule.mandatory_checks_name):
if checkname not in checks:
pending_checks.append((checkname, None))
elif checks[checkname][0] is None:
pending_checks.append((checkname, checks[checkname][1]))
elif checks[checkname][0] != 'SUCCESS':
failed_checks.append((checkname, checks[checkname][1]))
def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
mandatory_checks = rule.mandatory_checks_name if rule.mandatory_checks_name is not None else []
checks = pr.get_checkrun_conclusions()
required_checks = filter(lambda x: force is False or "CLA Check" in x, mandatory_checks)
[pending_checks, failed_checks] = categorize_checks(checks, required_checks)
if len(failed_checks) > 0:
if reject_reason_score < 30000:
@@ -874,6 +901,9 @@ def find_matching_merge_rule(pr: GitHubPR,
raise RuntimeError(reject_reason)
def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
def pr_get_checks_with_lambda(pr: GitHubPR, status_check: Callable[[Optional[str]], bool]) -> List[Tuple[str, str]]:
checks = pr.get_checkrun_conclusions()
return [(name, status[1]) for name, status in checks.items() if status_check(status[0])]
@@ -938,7 +968,6 @@ def try_revert(repo: GitRepo, pr: GitHubPR, *,
def prefix_with_github_url(suffix_str: str) -> str:
return f"https://github.com/{suffix_str}"
def check_for_sev(org: str, project: str, force: bool) -> None:
if force:
return
@@ -959,6 +988,38 @@ def check_for_sev(org: str, project: str, force: bool) -> None:
)
return
def fetch_check_run_conclusions(repo: GitRepo, commit: str) -> Dict[str, Tuple[str, str]]:
[owner, name] = repo.gh_owner_and_name()
checks = fetch_json_dict(f'https://api.github.com/repos/{owner}/{name}/commits/{commit}/check-runs')
check_run_conclusions = {}
if len(checks) == 0:
raise MandatoryChecksMissingError("Refusing to merge as land check(s) are not yet run")
for check_run in checks['check_runs']:
check_run_conclusions[check_run['name']] = (check_run['conclusion'],
check_run['html_url'])
return check_run_conclusions
def validate_land_time_checks(repo: GitRepo, commit: str) -> None:
checks = fetch_check_run_conclusions(repo, commit)
[pending_checks, failed_checks] = categorize_checks(checks, checks)
if len(failed_checks) > 0:
raise RuntimeError(f"Failed to merge; some land checks failed: {checks_to_str(failed_checks)}")
if len(pending_checks) > 0:
raise MandatoryChecksMissingError(f"Refusing to merge as land check(s) {checks_to_str(pending_checks)} are not yet run")
def categorize_checks(check_runs: Dict[str, Tuple[str, str]],
required_checks: Iterable[str]) -> Tuple[List[Tuple[str, Optional[str]]], List[Tuple[str, Optional[str]]]]:
pending_checks: List[Tuple[str, Optional[str]]] = []
failed_checks: List[Tuple[str, Optional[str]]] = []
for checkname in required_checks:
if checkname not in check_runs:
pending_checks.append((checkname, None))
elif check_runs[checkname][0] is None:
pending_checks.append((checkname, check_runs[checkname][1]))
elif check_runs[checkname][0].upper() != 'SUCCESS' and check_runs[checkname][0].upper() != 'SKIPPED':
failed_checks.append((checkname, check_runs[checkname][1]))
return (pending_checks, failed_checks)
def merge(pr_num: int, repo: GitRepo,
dry_run: bool = False,
@@ -966,6 +1027,7 @@ def merge(pr_num: int, repo: GitRepo,
comment_id: Optional[int] = None,
mandatory_only: bool = False,
on_green: bool = False,
land_checks: bool = False,
timeout_minutes: int = 400,
stale_pr_days: int = 3) -> None:
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
@@ -979,6 +1041,9 @@ def merge(pr_num: int, repo: GitRepo,
if (datetime.utcnow() - pr.last_pushed_at()).days > stale_pr_days:
raise RuntimeError("This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again.")
if land_checks:
commit = pr.create_land_time_check_branch(repo, 'viable/strict', force=force, comment_id=comment_id)
start_time = time.time()
last_exception = ''
elapsed_time = 0.0
@@ -994,12 +1059,16 @@ def merge(pr_num: int, repo: GitRepo,
find_matching_merge_rule(pr, repo)
pending = pr_get_pending_checks(pr)
failing = pr_get_failed_checks(pr)
if (not mandatory_only and on_green) and len(failing) > 0:
raise RuntimeError(f"{len(failing)} additional jobs have failed, first few of them are: " +
' ,'.join(f"[{x[0]}]({x[1]})" for x in failing[:5]))
if (not mandatory_only and on_green) and len(pending) > 0:
raise MandatoryChecksMissingError(f"Still waiting for {len(pending)} additional jobs to finish, " +
f"first few of them are: {' ,'.join(x[0] for x in pending[:5])}")
if land_checks:
validate_land_time_checks(repo, commit)
return pr.merge_into(repo, dry_run=dry_run, force=force, comment_id=comment_id)
except MandatoryChecksMissingError as ex:
last_exception = str(ex)
@@ -1052,11 +1121,11 @@ def main() -> None:
force=args.force,
comment_id=args.comment_id,
on_green=args.on_green,
mandatory_only=args.on_mandatory)
mandatory_only=args.on_mandatory,
land_checks=args.land_checks)
except Exception as e:
handle_exception(e)
if __name__ == "__main__":
main()