mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move check_label ci to mergebot (#92309)
Fixes #88098 ### What Changed * Moved `check_label.py` logic into `trymerge.py` * Refactored relevant unittests * ~~Dropped~~ Refactored `check_label.py` ci job ### Tests `python .github/scripts/test_trymerge.py` `python .github/scripts/test_check_labels.py` `make lint & lintrunner -a` ### Notes to reviewers This PR replaces the [original PR](https://github.com/pytorch/pytorch/pull/92225) to workaround the sticky EasyCLA failure mark on its first commit. Pull Request resolved: https://github.com/pytorch/pytorch/pull/92309 Approved by: https://github.com/ZainRizvi
This commit is contained in:
committed by
PyTorch MergeBot
parent
b33d9e2c87
commit
190f7803f5
49
.github/scripts/trymerge.py
vendored
49
.github/scripts/trymerge.py
vendored
@ -14,6 +14,7 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
@ -33,6 +34,7 @@ from gitutils import (
|
||||
get_git_repo_dir,
|
||||
patterns_to_regex,
|
||||
)
|
||||
from export_pytorch_labels import get_pytorch_labels
|
||||
from trymerge_explainer import (
|
||||
TryMergeExplainer,
|
||||
get_revert_message,
|
||||
@ -422,6 +424,16 @@ CIFLOW_LABEL = re.compile(r"^ciflow/.+")
|
||||
CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk")
|
||||
MERGE_RULE_PATH = Path(".github") / "merge_rules.yaml"
|
||||
|
||||
BOT_AUTHORS = ["github-actions", "pytorchmergebot", "pytorch-bot"]
|
||||
|
||||
LABEL_ERR_MSG_TITLE = "This PR needs a label"
|
||||
LABEL_ERR_MSG = (
|
||||
f"# {LABEL_ERR_MSG_TITLE}\n"
|
||||
"If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.\n\n" # noqa: E501 pylint: disable=line-too-long
|
||||
"If not, please add the `topic: not user facing` label.\n\n"
|
||||
"For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work." # noqa: E501 pylint: disable=line-too-long
|
||||
)
|
||||
|
||||
|
||||
def _fetch_url(url: str, *,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
@ -1397,10 +1409,34 @@ def validate_land_time_checks(org: str, project: str, commit: str) -> None:
|
||||
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
|
||||
return len(list(filter(pattern.match, labels))) > 0
|
||||
|
||||
def categorize_checks(
|
||||
check_runs: JobNameToStateDict,
|
||||
required_checks: List[str],
|
||||
) -> Tuple[List[Tuple[str, Optional[str]]], List[Tuple[str, Optional[str]]]]:
|
||||
def get_release_notes_labels() -> List[str]:
|
||||
return [label for label in get_pytorch_labels() if label.lstrip().startswith("release notes:")]
|
||||
|
||||
def has_required_labels(pr: GitHubPR) -> bool:
|
||||
pr_labels = pr.get_labels()
|
||||
# Check if PR is not user facing
|
||||
is_not_user_facing_pr = any(label.strip() == "topic: not user facing" for label in pr_labels)
|
||||
return is_not_user_facing_pr or any(label.strip() in get_release_notes_labels() for label in pr_labels)
|
||||
|
||||
def delete_comment(comment_id: int) -> None:
|
||||
url = f"https://api.github.com/repos/pytorch/pytorch/issues/comments/{comment_id}"
|
||||
_fetch_url(url, method="DELETE")
|
||||
|
||||
def is_label_err_comment(comment: GitHubComment) -> bool:
|
||||
return comment.body_text.lstrip(" #").startswith(LABEL_ERR_MSG_TITLE) and comment.author_login in BOT_AUTHORS
|
||||
|
||||
def delete_all_label_err_comments(pr: GitHubPR) -> None:
|
||||
for comment in pr.get_comments():
|
||||
if is_label_err_comment(comment):
|
||||
delete_comment(comment.database_id)
|
||||
|
||||
def add_label_err_comment(pr: GitHubPR) -> None:
|
||||
# Only make a comment if one doesn't exist already
|
||||
if not any(is_label_err_comment(comment) for comment in pr.get_comments()):
|
||||
gh_post_pr_comment(pr.org, pr.project, pr.pr_num, LABEL_ERR_MSG)
|
||||
|
||||
def categorize_checks(check_runs: Dict[str, JobCheckState],
|
||||
required_checks: Iterable[str]) -> Tuple[List[Tuple[str, Optional[str]]], List[Tuple[str, Optional[str]]]]:
|
||||
pending_checks: List[Tuple[str, Optional[str]]] = []
|
||||
failed_checks: List[Tuple[str, Optional[str]]] = []
|
||||
|
||||
@ -1456,6 +1492,11 @@ def merge(pr_num: int, repo: GitRepo,
|
||||
# here to stop the merge process right away
|
||||
find_matching_merge_rule(pr, repo, skip_mandatory_checks=True)
|
||||
|
||||
if not has_required_labels(pr):
|
||||
raise RuntimeError(LABEL_ERR_MSG.lstrip(" #"))
|
||||
else:
|
||||
delete_all_label_err_comments(pr)
|
||||
|
||||
if land_checks and not dry_run:
|
||||
land_check_commit = pr.create_land_time_check_branch(
|
||||
repo,
|
||||
|
Reference in New Issue
Block a user