mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remake of https://github.com/pytorch/pytorch/pull/140587 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140578 Approved by: https://github.com/huydhn
63 lines
1.9 KiB
Python
Executable File
63 lines
1.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Check whether a PR has required labels."""
|
|
|
|
import sys
|
|
from typing import Any
|
|
|
|
from github_utils import gh_delete_comment, gh_post_pr_comment
|
|
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
|
from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG
|
|
from trymerge import GitHubPR
|
|
|
|
|
|
def delete_all_label_err_comments(pr: "GitHubPR") -> None:
    """Remove every label-error comment found on *pr*.

    Each comment returned by ``pr.get_comments()`` that
    ``is_label_err_comment`` recognises is deleted through the GitHub API
    using its ``database_id``.
    """
    stale_comments = (c for c in pr.get_comments() if is_label_err_comment(c))
    for stale in stale_comments:
        gh_delete_comment(pr.org, pr.project, stale.database_id)
|
|
|
|
|
|
def add_label_err_comment(pr: "GitHubPR") -> None:
    """Post ``LABEL_ERR_MSG`` on *pr* unless a label-error comment already exists.

    The existence check avoids posting a duplicate warning when the check is
    run again on the same PR.
    """
    already_posted = any(map(is_label_err_comment, pr.get_comments()))
    if already_posted:
        return
    gh_post_pr_comment(pr.org, pr.project, pr.pr_num, LABEL_ERR_MSG)
|
|
|
|
|
|
def parse_args() -> Any:
    """Parse the command line.

    Positional ``pr_num`` (int) is the PR to check; ``--exit-non-zero`` is an
    opt-in flag that makes a missing-label result fail the process.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser("Check PR labels")
    cli.add_argument("pr_num", type=int)
    # Opt-in strict mode: without it the script only warns.
    cli.add_argument(
        "--exit-non-zero",
        help="Return a non-zero exit code if the PR does not have the required labels",
        action="store_true",
    )
    return cli.parse_args()
|
|
|
|
|
|
def main() -> None:
    """Check the PR named on the command line for the required labels.

    If the labels are missing, print ``LABEL_ERR_MSG`` and post it as a PR
    comment (unless one is already present); otherwise delete any earlier
    label-error comments.

    Exceptions raised while checking labels or editing comments are, by
    default, reported on stderr and swallowed, and the script exits 0.

    Raises:
        RuntimeError: only when ``--exit-non-zero`` is given. The message is
            "PR does not have required labels" when labels are missing, or
            "Error checking labels: ..." (chained from the original error)
            when a check/comment call fails.
    """
    args = parse_args()
    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, args.pr_num)

    try:
        missing_labels = not has_required_labels(pr)
        if missing_labels:
            print(LABEL_ERR_MSG, flush=True)
            add_label_err_comment(pr)
        else:
            delete_all_label_err_comments(pr)
    except Exception as e:
        if args.exit_non_zero:
            raise RuntimeError(f"Error checking labels: {e}") from e
        # Best-effort mode: don't fail the job, but leave a trace in the log
        # instead of silently discarding the error.
        print(f"Error checking labels: {e}", file=sys.stderr, flush=True)
    else:
        # Raised outside the try so it is not caught and re-wrapped as
        # "Error checking labels: ..." by the handler above.
        if missing_labels and args.exit_non_zero:
            raise RuntimeError("PR does not have required labels")

    sys.exit(0)
|
|
|
|
|
|
# Script entry point: run the label check only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|