mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Add pytorchbot to list of approvers for file * Add labels to the auto created PR The auto generated PR is currently not merging due to some failing tests on slow workflow that were supposed to be moved back to normal idk if this has much value, clearly we've been managing without the update Pull Request resolved: https://github.com/pytorch/pytorch/pull/135390 Approved by: https://github.com/ZainRizvi
219 lines
6.1 KiB
Python
219 lines
6.1 KiB
Python
import json
|
|
import os
|
|
import subprocess
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, cast, Dict, List, Optional, Tuple
|
|
|
|
import requests
|
|
import rockset # type: ignore[import]
|
|
|
|
|
|
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
|
QUERY = """
|
|
WITH most_recent_strict_commits AS (
|
|
SELECT
|
|
push.head_commit.id as sha,
|
|
FROM
|
|
commons.push
|
|
WHERE
|
|
push.ref = 'refs/heads/viable/strict'
|
|
AND push.repository.full_name = 'pytorch/pytorch'
|
|
ORDER BY
|
|
push._event_time DESC
|
|
LIMIT
|
|
3
|
|
), workflows AS (
|
|
SELECT
|
|
id
|
|
FROM
|
|
commons.workflow_run w
|
|
INNER JOIN most_recent_strict_commits c on w.head_sha = c.sha
|
|
WHERE
|
|
w.name != 'periodic'
|
|
),
|
|
job AS (
|
|
SELECT
|
|
j.id
|
|
FROM
|
|
commons.workflow_job j
|
|
INNER JOIN workflows w on w.id = j.run_id
|
|
WHERE
|
|
j.name NOT LIKE '%asan%'
|
|
),
|
|
duration_per_job AS (
|
|
SELECT
|
|
test_run.classname,
|
|
test_run.name,
|
|
job.id,
|
|
SUM(time) as time
|
|
FROM
|
|
commons.test_run_s3 test_run
|
|
/* `test_run` is ginormous and `job` is small, so lookup join is essential */
|
|
INNER JOIN job ON test_run.job_id = job.id HINT(join_strategy = lookup)
|
|
WHERE
|
|
/* cpp tests do not populate `file` for some reason. */
|
|
/* Exclude them as we don't include them in our slow test infra */
|
|
test_run.file IS NOT NULL
|
|
/* do some more filtering to cut down on the test_run size */
|
|
AND test_run.skipped IS NULL
|
|
AND test_run.failure IS NULL
|
|
AND test_run.error IS NULL
|
|
GROUP BY
|
|
test_run.classname,
|
|
test_run.name,
|
|
job.id
|
|
)
|
|
SELECT
|
|
CONCAT(
|
|
name,
|
|
' (__main__.',
|
|
classname,
|
|
')'
|
|
) as test_name,
|
|
AVG(time) as avg_duration_sec
|
|
FROM
|
|
duration_per_job
|
|
GROUP BY
|
|
CONCAT(
|
|
name,
|
|
' (__main__.',
|
|
classname,
|
|
')'
|
|
)
|
|
HAVING
|
|
AVG(time) > 60.0
|
|
ORDER BY
|
|
test_name
|
|
"""
|
|
|
|
|
|
UPDATEBOT_TOKEN = os.environ["UPDATEBOT_TOKEN"]
|
|
PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]
|
|
|
|
|
|
def git_api(
|
|
url: str, params: Dict[str, Any], type: str = "get", token: str = UPDATEBOT_TOKEN
|
|
) -> Any:
|
|
headers = {
|
|
"Accept": "application/vnd.github.v3+json",
|
|
"Authorization": f"token {token}",
|
|
}
|
|
if type == "post":
|
|
return requests.post(
|
|
f"https://api.github.com{url}",
|
|
data=json.dumps(params),
|
|
headers=headers,
|
|
).json()
|
|
elif type == "patch":
|
|
return requests.patch(
|
|
f"https://api.github.com{url}",
|
|
data=json.dumps(params),
|
|
headers=headers,
|
|
).json()
|
|
else:
|
|
return requests.get(
|
|
f"https://api.github.com{url}",
|
|
params=params,
|
|
headers=headers,
|
|
).json()
|
|
|
|
|
|
def make_pr(source_repo: str, params: Dict[str, Any]) -> int:
|
|
response = git_api(f"/repos/{source_repo}/pulls", params, type="post")
|
|
print(f"made pr {response['html_url']}")
|
|
return cast(int, response["number"])
|
|
|
|
|
|
def approve_pr(source_repo: str, pr_number: int) -> None:
|
|
params = {"event": "APPROVE"}
|
|
# use pytorchbot to approve the pr
|
|
git_api(
|
|
f"/repos/{source_repo}/pulls/{pr_number}/reviews",
|
|
params,
|
|
type="post",
|
|
token=PYTORCHBOT_TOKEN,
|
|
)
|
|
|
|
|
|
def make_comment(source_repo: str, pr_number: int, msg: str) -> None:
|
|
params = {"body": msg}
|
|
# comment with pytorchbot because pytorchmergebot gets ignored
|
|
git_api(
|
|
f"/repos/{source_repo}/issues/{pr_number}/comments",
|
|
params,
|
|
type="post",
|
|
token=PYTORCHBOT_TOKEN,
|
|
)
|
|
|
|
|
|
def add_labels(source_repo: str, pr_number: int, labels: List[str]) -> None:
|
|
params = {"labels": labels}
|
|
git_api(
|
|
f"/repos/{source_repo}/issues/{pr_number}/labels",
|
|
params,
|
|
type="post",
|
|
)
|
|
|
|
|
|
def search_for_open_pr(
|
|
source_repo: str, search_string: str
|
|
) -> Optional[Tuple[int, str]]:
|
|
params = {
|
|
"q": f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}",
|
|
"sort": "created",
|
|
}
|
|
response = git_api("/search/issues", params)
|
|
if response["total_count"] != 0:
|
|
# pr does exist
|
|
pr_num = response["items"][0]["number"]
|
|
link = response["items"][0]["html_url"]
|
|
response = git_api(f"/repos/{source_repo}/pulls/{pr_num}", {})
|
|
branch_name = response["head"]["ref"]
|
|
print(
|
|
f"pr does exist, number is {pr_num}, branch name is {branch_name}, link is {link}"
|
|
)
|
|
return pr_num, branch_name
|
|
return None
|
|
|
|
|
|
if __name__ == "__main__":
|
|
rs_client = rockset.RocksetClient(
|
|
host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
|
|
)
|
|
|
|
results = rs_client.sql(QUERY).results
|
|
slow_tests = {row["test_name"]: row["avg_duration_sec"] for row in results}
|
|
|
|
with open(REPO_ROOT / "test" / "slow_tests.json", "w") as f:
|
|
json.dump(slow_tests, f, indent=2)
|
|
|
|
branch_name = f"update_slow_tests_{int(time.time())}"
|
|
pr_num = None
|
|
|
|
open_pr = search_for_open_pr("pytorch/pytorch", "Update slow tests")
|
|
if open_pr is not None:
|
|
pr_num, branch_name = open_pr
|
|
|
|
subprocess.run(["git", "checkout", "-b", branch_name], cwd=REPO_ROOT)
|
|
subprocess.run(["git", "add", "test/slow_tests.json"], cwd=REPO_ROOT)
|
|
subprocess.run(["git", "commit", "-m", "Update slow tests"], cwd=REPO_ROOT)
|
|
subprocess.run(
|
|
f"git push --set-upstream origin {branch_name} -f".split(), cwd=REPO_ROOT
|
|
)
|
|
|
|
params = {
|
|
"title": "Update slow tests",
|
|
"head": branch_name,
|
|
"base": "main",
|
|
"body": "This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/"
|
|
+ ".github/workflows/weekly.yml).\nUpdate the list of slow tests.",
|
|
}
|
|
if pr_num is None:
|
|
# no existing pr, so make a new one and approve it
|
|
pr_num = make_pr("pytorch/pytorch", params)
|
|
time.sleep(5)
|
|
add_labels("pytorch/pytorch", pr_num, ["ciflow/slow", "ci-no-td"])
|
|
approve_pr("pytorch/pytorch", pr_num)
|
|
make_comment("pytorch/pytorch", pr_num, "@pytorchbot merge")
|