Files
pytorch/.github/scripts/github_utils.py
Nikita Shulga 90c0825e2d [GHF] Allow reverts from pytorch-auto-revert app (#164911)
This is a bit weird, but author_login is not a unique field, while author_url is.

Explicitly allow https://github.com/apps/pytorch-auto-revert to issue revert commands

Update mocks by running
```
sed -i -e s/8e262b0495bd934d39dda198d4c09144311c5ddd6cca6a227194bd48dbfe7201/47860a8f57a214a426d1150c29893cbc2aa49507f12b731483b1a1254bca3428/ gql_mocks.json
```

Test plan: Run
```python
from trymerge import GitHubPR
pr=GitHubPR("pytorch", "pytorch", 164660)
print(pr.get_last_comment().author_url, pr.get_comment_by_id(3375785595).author_url)
```
that should produce
```
https://github.com/pytorch-auto-revert https://github.com/apps/pytorch-auto-revert
```
Plus added a regression test that checks two particular comments for revert validity

`pytorch-auto-revert` user is my alter ego :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164911
Approved by: https://github.com/jeanschmidt
2025-10-08 15:15:45 +00:00

230 lines
7.1 KiB
Python

"""GitHub Utilities"""
import json
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, cast, Optional, Union
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen
# Root of the GitHub REST API. The REST endpoint URLs in this module are built
# from it, and gh_fetch_url_and_headers only attaches GITHUB_TOKEN to requests
# whose URL starts with this prefix.
GITHUB_API_URL = "https://api.github.com"
@dataclass
class GitHubComment:
    """Plain record describing a single GitHub comment."""

    # Text of the comment.
    body_text: str
    # Creation time, kept as the string it was received as.
    created_at: str
    # Login of the author. Not unique on its own: a regular user and a GitHub
    # App may share it (e.g. https://github.com/pytorch-auto-revert vs
    # https://github.com/apps/pytorch-auto-revert).
    author_login: str
    # URL of the author; unlike author_login this one is unique. May be None.
    author_url: Optional[str]
    # Author's association with the repository, as a string.
    author_association: str
    # Login of the editor; presumably None when the comment was never
    # edited -- TODO confirm against the producer of this record.
    editor_login: Optional[str]
    # Numeric identifier of the comment.
    database_id: int
    # URL of the comment itself.
    url: str
def gh_fetch_url_and_headers(
    url: str,
    *,
    headers: Optional[dict[str, str]] = None,
    data: Union[Optional[dict[str, Any]], str] = None,
    method: Optional[str] = None,
    reader: Callable[[Any], Any] = lambda x: x.read(),
) -> tuple[Any, Any]:
    """Send an HTTP request and return the response headers and decoded body.

    When the ``GITHUB_TOKEN`` environment variable is set and ``url`` lives
    under ``GITHUB_API_URL``, an ``Authorization: token ...`` header is added
    to the request.

    Args:
        url: Address to request.
        headers: Extra request headers. The mapping is copied; the caller's
            dict is never modified.
        data: Request body. A ``str`` is sent encoded as is, any other
            non-None value is serialized with ``json.dumps`` first.
        method: HTTP method passed to ``urllib.request.Request``.
        reader: Callable applied to the open response object; its result is
            the second element of the returned tuple. Defaults to reading the
            raw bytes.

    Returns:
        A ``(response_headers, reader(response))`` tuple.

    Raises:
        HTTPError: Re-raised unchanged after a diagnostic has been printed.
    """
    # Work on a private copy: adding the Authorization header to the caller's
    # own dict would leak the token into an object the caller may reuse (for
    # instance for a later request to a non-GitHub host).
    request_headers = dict(headers) if headers is not None else {}
    token = os.environ.get("GITHUB_TOKEN")
    if token is not None and url.startswith(f"{GITHUB_API_URL}/"):
        request_headers["Authorization"] = f"token {token}"
    data_ = None
    if data is not None:
        data_ = data.encode() if isinstance(data, str) else json.dumps(data).encode()
    try:
        request = Request(url, headers=request_headers, data=data_, method=method)
        with urlopen(request) as conn:
            return conn.headers, reader(conn)
    except HTTPError as err:
        # GitHub signals an exhausted primary rate limit with either 403 or
        # 429, together with a remaining-quota header equal to 0.
        if (
            err.code in (403, 429)
            and all(
                key in err.headers
                for key in ["X-RateLimit-Limit", "X-RateLimit-Remaining"]
            )
            and int(err.headers["X-RateLimit-Remaining"]) == 0
        ):
            # The continuation lines are flush-left on purpose: they are part
            # of the string literal and are printed verbatim.
            print(
                f"""{url}
Rate limit exceeded:
Used: {err.headers["X-RateLimit-Used"]}
Limit: {err.headers["X-RateLimit-Limit"]}
Remaining: {err.headers["X-RateLimit-Remaining"]}
Resets at: {err.headers["x-RateLimit-Reset"]}"""
            )
        else:
            print(f"Error fetching {url} {err}")
        raise
def gh_fetch_url(
    url: str,
    *,
    headers: Optional[dict[str, str]] = None,
    data: Union[Optional[dict[str, Any]], str] = None,
    method: Optional[str] = None,
    reader: Callable[[Any], Any] = json.load,
) -> Any:
    """Fetch ``url`` and return only the decoded body.

    Thin wrapper over gh_fetch_url_and_headers that discards the response
    headers. Unlike that function, the default ``reader`` here is
    ``json.load``.
    """
    headers_and_body = gh_fetch_url_and_headers(
        url, headers=headers, data=data, reader=reader, method=method
    )
    # Element 0 holds the response headers, element 1 the reader's result.
    return headers_and_body[1]
def gh_fetch_json(
    url: str,
    params: Optional[dict[str, Any]] = None,
    data: Optional[dict[str, Any]] = None,
    method: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Request ``url`` with optional query parameters and return parsed JSON.

    Each value in ``params`` is converted to ``str`` and percent-quoted;
    parameter names are appended as given. The return value is whatever
    ``json.load`` produces -- the ``cast`` only informs the type checker and
    performs no runtime validation.
    """
    if params:
        query_string = "&".join(
            f"{name}={quote(str(val))}" for name, val in params.items()
        )
        url = f"{url}?{query_string}"
    response = gh_fetch_url(
        url,
        headers={"Accept": "application/vnd.github.v3+json"},
        data=data,
        reader=json.load,
        method=method,
    )
    return cast(list[dict[str, Any]], response)
def _gh_fetch_json_any(
    url: str,
    params: Optional[dict[str, Any]] = None,
    data: Optional[dict[str, Any]] = None,
) -> Any:
    """Request ``url`` and return the parsed JSON without narrowing its type.

    Shared implementation behind gh_fetch_json_list and gh_fetch_json_dict.
    Values in ``params`` are stringified and percent-quoted before being
    appended to ``url`` as a query string.
    """
    if params:
        encoded_pairs = [
            f"{name}={quote(str(val))}" for name, val in params.items()
        ]
        url = url + "?" + "&".join(encoded_pairs)
    accept_json = {"Accept": "application/vnd.github.v3+json"}
    return gh_fetch_url(url, headers=accept_json, data=data, reader=json.load)
def gh_fetch_json_list(
    url: str,
    params: Optional[dict[str, Any]] = None,
    data: Optional[dict[str, Any]] = None,
) -> list[dict[str, Any]]:
    """Fetch JSON from ``url`` for callers that expect a list of objects.

    The ``cast`` is for the type checker only; the payload is not inspected
    at runtime.
    """
    payload = _gh_fetch_json_any(url, params, data)
    return cast(list[dict[str, Any]], payload)
def gh_fetch_json_dict(
    url: str,
    params: Optional[dict[str, Any]] = None,
    data: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Fetch JSON from ``url`` for callers that expect a single object.

    The ``cast`` is for the type checker only; the payload is not inspected
    at runtime.
    """
    payload = _gh_fetch_json_any(url, params, data)
    return cast(dict[str, Any], payload)
def gh_graphql(query: str, **kwargs: Any) -> dict[str, Any]:
    """Run a GraphQL query against GitHub and return the decoded response.

    Args:
        query: GraphQL query text.
        **kwargs: Sent as the query's ``variables``.

    Raises:
        RuntimeError: If the response contains an ``errors`` entry.
    """
    request_body = {"query": query, "variables": kwargs}
    response = gh_fetch_url(
        "https://api.github.com/graphql",  # @lint-ignore
        data=request_body,
        reader=json.load,
    )
    if "errors" not in response:
        return cast(dict[str, Any], response)
    raise RuntimeError(
        f"GraphQL query {query}, args {kwargs} failed: {response['errors']}"
    )
def _gh_post_comment(
    url: str, comment: str, dry_run: bool = False
) -> list[dict[str, Any]]:
    """Post ``comment`` to the comments endpoint at ``url``.

    With ``dry_run`` set, nothing is sent: the comment is printed and an
    empty list is returned. Otherwise the comment is submitted as the
    ``body`` field and the decoded response is returned.
    """
    if not dry_run:
        return gh_fetch_json_list(url, data={"body": comment})
    print(comment)
    return []
def gh_post_pr_comment(
    org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
) -> list[dict[str, Any]]:
    """Post ``comment`` on pull request ``pr_num`` of ``org/repo``.

    The comment goes through the repository's issues comments endpoint.
    ``dry_run`` is forwarded to _gh_post_comment, which then only prints.
    """
    comments_url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/{pr_num}/comments"
    return _gh_post_comment(comments_url, comment, dry_run)
def gh_post_commit_comment(
    org: str, repo: str, sha: str, comment: str, dry_run: bool = False
) -> list[dict[str, Any]]:
    """Post ``comment`` on commit ``sha`` of ``org/repo``.

    ``dry_run`` is forwarded to _gh_post_comment, which then only prints.
    """
    comments_url = f"{GITHUB_API_URL}/repos/{org}/{repo}/commits/{sha}/comments"
    return _gh_post_comment(comments_url, comment, dry_run)
def gh_close_pr(org: str, repo: str, pr_num: int, dry_run: bool = False) -> None:
    """Close pull request ``pr_num`` of ``org/repo``.

    With ``dry_run`` set, only a message is printed and no request is made.
    """
    pr_url = f"{GITHUB_API_URL}/repos/{org}/{repo}/pulls/{pr_num}"
    if dry_run:
        print(f"Dry run closing PR {pr_num}")
        return
    gh_fetch_url(pr_url, method="PATCH", data={"state": "closed"})
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
    """Delete the issue comment ``comment_id`` from ``org/repo``."""
    comment_url = (
        f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/comments/{comment_id}"
    )
    # Read the raw bytes instead of using the default json.load reader --
    # presumably because the DELETE response carries no JSON body.
    gh_fetch_url(comment_url, method="DELETE", reader=lambda resp: resp.read())
def gh_fetch_merge_base(org: str, repo: str, base: str, head: str) -> str:
    """Return the SHA of the merge base of ``base`` and ``head``.

    Best effort: any failure (request error, empty or unexpected response)
    is reported through ``warnings.warn`` and results in an empty string
    instead of an exception.
    """
    # Get the merge base using the GitHub REST API. This is the same as using
    # git merge-base without the need to have git. The API doc can be found at
    # https://docs.github.com/en/rest/commits/commits?apiVersion=2022-11-28#compare-two-commits
    try:
        comparison = gh_fetch_url(
            f"{GITHUB_API_URL}/repos/{org}/{repo}/compare/{base}...{head}",
            headers={"Accept": "application/vnd.github.v3+json"},
            reader=json.load,
        )
        if not comparison:
            warnings.warn(
                f"Failed to get merge base for {base}...{head}: Empty response"
            )
            return ""
        # A missing key at either level falls back to the empty string.
        return comparison.get("merge_base_commit", {}).get("sha", "")
    except Exception as error:
        warnings.warn(f"Failed to get merge base for {base}...{head}: {error}")
        return ""
def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") -> None:
    """Set the state of pull request ``pr_num`` of ``org/repo`` to ``state``.

    Raises:
        HTTPError: For any failure other than a 422 while re-opening.
    """
    pr_url = f"{GITHUB_API_URL}/repos/{org}/{repo}/pulls/{pr_num}"
    try:
        gh_fetch_url(pr_url, method="PATCH", data={"state": state})
    except HTTPError as err:
        # When trying to open the pull request, error 422 means that the branch
        # has been deleted and the API couldn't re-open it
        head_branch_deleted = err.code == 422 and state == "open"
        if not head_branch_deleted:
            raise
        warnings.warn(
            f"Failed to open {pr_num} because its head branch has been deleted: {err}"
        )
def gh_query_issues_by_labels(
    org: str, repo: str, labels: list[str], state: str = "open"
) -> list[dict[str, Any]]:
    """Query the issues endpoint of ``org/repo`` filtered by labels and state.

    ``labels`` are joined with commas into a single ``labels`` query
    parameter; ``state`` is passed through as the ``state`` parameter.
    """
    issues_url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
    query = {"labels": ",".join(labels), "state": state}
    return gh_fetch_json(issues_url, method="GET", params=query)