Files
pytorch/.github/scripts/delete_old_branches.py
Catherine Lee c09cf29d7d [ez][BE] Tag deletion script to delete any old ciflow + autorevert tags (#157468)
Change the branch/tag deletion script that runs once per day to delete more tags

Previous: only delete ciflow tags that didn't correspond to an open PR
New: delete ciflow tags attached to commits that are > 7 days old.  Also delete `trunk/<sha>` (I think they are for autorevert) tags that are attached to commits that are > 7 days old

It's hard to figure out when the actual tag was pushed or created, so the script looks at the commit date instead. This might lead to unexpected behavior if the tag was pushed much later than the commit (e.g. triggering a periodic workflow later in order to bisect). I think it's OK though, since you don't really need the tag after the workflow runs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157468
Approved by: https://github.com/izaitsevfb
2025-07-02 20:42:32 +00:00

322 lines
11 KiB
Python

# Delete old branches
import os
import re
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable
from github_utils import gh_fetch_json_dict, gh_graphql
from gitutils import GitRepo
# Retention windows and the PR look-back window, all expressed in seconds.
SEC_IN_DAY = 24 * 60 * 60
CLOSED_PR_RETENTION = 30 * SEC_IN_DAY
NO_PR_RETENTION = 1.5 * 365 * SEC_IN_DAY
PR_WINDOW = 90 * SEC_IN_DAY  # Set to None to look at all PRs (may take a lot of tokens)
REPO_OWNER = "pytorch"
REPO_NAME = "pytorch"
# Running count of API calls / pushes made so far; kept in a one-element list
# and bumped in place (ESTIMATED_TOKENS[0] += 1) by the functions below.
ESTIMATED_TOKENS = [0]
# Use .get() rather than indexing: os.environ["GITHUB_TOKEN"] would raise a
# bare KeyError when the variable is absent, so the explicit check below (and
# its clearer message) could only ever fire for an empty value.
TOKEN = os.environ.get("GITHUB_TOKEN")
if not TOKEN:
    raise Exception("GITHUB_TOKEN is not set")  # noqa: TRY002
# Two directory levels above the directory that contains this file.
REPO_ROOT = Path(__file__).parents[2]
# Query for all PRs instead of just closed/merged because it's faster
# Pages through the repository's pull requests 100 at a time, most recently
# updated first, returning each PR's head branch name, number, last-update
# time and state. `$cursor` is the pagination cursor (null for the first page).
GRAPHQL_ALL_PRS_BY_UPDATED_AT = """
query ($owner: String!, $repo: String!, $cursor: String) {
repository(owner: $owner, name: $repo) {
pullRequests(
first: 100
after: $cursor
orderBy: {field: UPDATED_AT, direction: DESC}
) {
totalCount
pageInfo {
hasNextPage
endCursor
}
nodes {
headRefName
number
updatedAt
state
}
}
}
}
"""
# Same node fields as above, but restricted to pull requests in the OPEN state
# and with no explicit ordering.
GRAPHQL_OPEN_PRS = """
query ($owner: String!, $repo: String!, $cursor: String) {
repository(owner: $owner, name: $repo) {
pullRequests(
first: 100
after: $cursor
states: [OPEN]
) {
totalCount
pageInfo {
hasNextPage
endCursor
}
nodes {
headRefName
number
updatedAt
state
}
}
}
}
"""
# Pull requests carrying the "no-delete-branch" label, paged 100 at a time.
# Note the result is nested one level deeper (repository -> label ->
# pullRequests) than in the two queries above.
GRAPHQL_NO_DELETE_BRANCH_LABEL = """
query ($owner: String!, $repo: String!, $cursor: String) {
repository(owner: $owner, name: $repo) {
label(name: "no-delete-branch") {
pullRequests(first: 100, after: $cursor) {
totalCount
pageInfo {
hasNextPage
endCursor
}
nodes {
headRefName
number
updatedAt
state
}
}
}
}
}
"""
def is_protected(branch: str) -> bool:
    """Report whether `branch` is marked as protected by the GitHub REST API.

    Every call adds one to ``ESTIMATED_TOKENS[0]``. If anything goes wrong
    while fetching or reading the response, the error is printed and the
    branch is reported as protected (``True``).
    """
    try:
        ESTIMATED_TOKENS[0] += 1
        branch_url = (
            f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/branches/{branch}"
        )
        branch_info = gh_fetch_json_dict(branch_url)
        protected = bool(branch_info["protected"])
    except Exception as err:
        print(f"[{branch}] Failed to fetch branch protections: {err}")
        return True
    else:
        return protected
def convert_gh_timestamp(date: str) -> float:
    """Convert a GitHub ``YYYY-MM-DDTHH:MM:SSZ`` timestamp to POSIX seconds.

    The trailing ``Z`` marks the value as UTC. ``strptime`` only matches it as
    a literal character and returns a naive datetime, which ``timestamp()``
    would interpret in the machine's local timezone; the UTC zone is therefore
    attached explicitly so the result is the same on every host.

    Raises:
        ValueError: if `date` does not match the expected format.
    """
    # Function-scope import so this block does not depend on the module's
    # top-level import list.
    from datetime import timezone

    parsed = datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ")
    return parsed.replace(tzinfo=timezone.utc).timestamp()
def get_branches(repo: GitRepo) -> dict[str, Any]:
    """List remote-tracking branches under ``origin``, grouped by base name.

    ghstack-style branches (``gh/<...>/head``, ``/base``, ``/orig``) are
    grouped under their shared ``gh/<...>`` prefix; every other branch is its
    own group.

    Returns:
        A dict mapping each base name to a two-element list
        ``[latest_timestamp, [branch, ...]]``, where ``latest_timestamp`` is
        the newest committer date (POSIX seconds) among the group's branches.
    """
    listing = repo._run_git(
        "for-each-ref",
        "--sort=creatordate",
        "--format=%(refname) %(committerdate:iso-strict)",
        "refs/remotes/origin",
    )
    grouped: dict[str, Any] = {}
    for entry in listing.splitlines():
        # Each line is "<refname> <iso-strict date>", as requested by --format.
        ref, iso_date = entry.split(" ")
        ref_match = re.match(r"refs/remotes/origin/(.*)", ref)
        assert ref_match
        short_name = ref_match.group(1)
        gh_match = re.match(r"(gh\/.+)\/(head|base|orig)", short_name)
        base_name = gh_match.group(1) if gh_match else short_name
        committed_at = datetime.fromisoformat(iso_date).timestamp()

        group = grouped.get(base_name)
        if group is None:
            grouped[base_name] = [committed_at, [short_name]]
            continue
        group[1].append(short_name)
        group[0] = max(group[0], committed_at)
    return grouped
def paginate_graphql(
    query: str,
    kwargs: dict[str, Any],
    termination_func: Callable[[list[dict[str, Any]]], bool],
    get_data: Callable[[dict[str, Any]], list[dict[str, Any]]],
    get_page_info: Callable[[dict[str, Any]], dict[str, Any]],
) -> list[Any]:
    """Run a cursor-paginated GraphQL query and collect the nodes of every page.

    Args:
        query: GraphQL query text; it is passed a ``cursor`` variable.
        kwargs: Extra query variables forwarded to ``gh_graphql``.
        termination_func: Called after each page with everything collected so
            far; a truthy result stops pagination early.
        get_data: Extracts the list of nodes from a response.
        get_page_info: Extracts the dict holding ``hasNextPage`` and
            ``endCursor`` from a response.

    Returns:
        All nodes gathered, in the order the pages were fetched. At least one
        request is always made, and each request adds one to
        ``ESTIMATED_TOKENS[0]``.
    """
    collected: list[dict[str, Any]] = []
    cursor = None
    while True:
        ESTIMATED_TOKENS[0] += 1
        response = gh_graphql(query, cursor=cursor, **kwargs)
        collected.extend(get_data(response))
        more_pages = get_page_info(response)["hasNextPage"]
        cursor = get_page_info(response)["endCursor"]
        # The early-stop callback is consulted on every page, including the last.
        if termination_func(collected) or not more_pages:
            break
    return collected
def get_recent_prs() -> dict[str, Any]:
    """Fetch recently updated PRs and keep the newest one per branch base name.

    PRs are requested most-recently-updated first. Paging stops once the last
    PR fetched so far was updated more than ``PR_WINDOW`` seconds ago; when
    ``PR_WINDOW`` is None every page is fetched. The stop check runs after a
    whole page has been added, so the result can include some PRs older than
    the window.

    Returns:
        A dict mapping branch base name to a PR node (``headRefName``,
        ``number``, ``updatedAt``, ``state``). ghstack branches
        (``gh/<...>/head|base|orig``) are keyed by their ``gh/<...>`` prefix,
        and ``updatedAt`` is converted to a float POSIX timestamp.
    """
    now = datetime.now().timestamp()

    def outside_window(data: list[dict[str, Any]]) -> bool:
        # With nothing fetched yet there is no "oldest PR" to inspect; indexing
        # data[-1] would raise IndexError, so just keep paging.
        if PR_WINDOW is None or not data:
            return False
        return now - convert_gh_timestamp(data[-1]["updatedAt"]) > PR_WINDOW

    # Grab all PRs updated in the last PR_WINDOW seconds
    pr_infos: list[dict[str, Any]] = paginate_graphql(
        GRAPHQL_ALL_PRS_BY_UPDATED_AT,
        {"owner": REPO_OWNER, "repo": REPO_NAME},
        outside_window,
        lambda res: res["data"]["repository"]["pullRequests"]["nodes"],
        lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"],
    )

    # Get the most recent PR for each branch base (group gh together)
    prs_by_branch_base: dict[str, Any] = {}
    for pr in pr_infos:
        pr["updatedAt"] = convert_gh_timestamp(pr["updatedAt"])
        branch_base_name = pr["headRefName"]
        if x := re.match(r"(gh\/.+)\/(head|base|orig)", branch_base_name):
            branch_base_name = x.group(1)
        current = prs_by_branch_base.get(branch_base_name)
        if current is None or pr["updatedAt"] > current["updatedAt"]:
            prs_by_branch_base[branch_base_name] = pr
    return prs_by_branch_base
@lru_cache(maxsize=1)
def get_open_prs() -> list[dict[str, Any]]:
    """Return the nodes of every open pull request in the repository.

    All pages are fetched (pagination is never cut short). The result is
    memoised by ``lru_cache``, so repeated calls return the same list object
    without querying again.
    """

    def never_stop(_fetched: list[dict[str, Any]]) -> bool:
        return False

    def pull_requests(res: dict[str, Any]) -> dict[str, Any]:
        return res["data"]["repository"]["pullRequests"]

    variables = {"owner": "pytorch", "repo": "pytorch"}
    return paginate_graphql(
        GRAPHQL_OPEN_PRS,
        variables,
        never_stop,
        lambda res: pull_requests(res)["nodes"],
        lambda res: pull_requests(res)["pageInfo"],
    )
def get_branches_with_magic_label_or_open_pr() -> set[str]:
    """Collect the branch base names that must never be deleted.

    A branch qualifies when it is the head of a PR labelled
    ``no-delete-branch`` or of any open PR. ghstack branches
    (``gh/<...>/head|base|orig``) are reported by their ``gh/<...>`` prefix.
    """

    def labelled_prs(res: dict[str, Any]) -> dict[str, Any]:
        return res["data"]["repository"]["label"]["pullRequests"]

    labelled: list[dict[str, Any]] = paginate_graphql(
        GRAPHQL_NO_DELETE_BRANCH_LABEL,
        {"owner": "pytorch", "repo": "pytorch"},
        lambda data: False,
        lambda res: labelled_prs(res)["nodes"],
        lambda res: labelled_prs(res)["pageInfo"],
    )

    keep: set[str] = set()
    for pr in [*labelled, *get_open_prs()]:
        head = pr["headRefName"]
        stacked = re.match(r"(gh\/.+)\/(head|base|orig)", head)
        keep.add(stacked.group(1) if stacked else head)
    return keep
def delete_branch(repo: GitRepo, branch: str) -> None:
    """Delete the ref named `branch` on ``origin`` via ``git push origin -d``.

    `branch` is passed to git as-is; ``delete_old_tags`` also uses this with a
    ``refs/tags/...`` name to remove tags.
    """
    push_args = ("push", "origin", "-d", branch)
    repo._run_git(*push_args)
def delete_branches() -> None:
    """Select stale remote branches and delete them from ``origin``.

    A branch group (see ``get_branches``) is kept when any of these hold:
      * it has the magic label or an open PR;
      * it has a recently fetched PR and either the PR or the branch was
        updated within CLOSED_PR_RETENTION;
      * it has no recently fetched PR and was updated within NO_PR_RETENTION;
      * any of its branches is protected.

    PR_WINDOW limits which PRs ``get_recent_prs`` fetches, so a branch whose PR
    was not fetched is handled by the no-PR rule. Selection stops early once
    ``ESTIMATED_TOKENS[0]`` exceeds 400; everything selected up to that point
    is still deleted.
    """
    current_time = datetime.now().timestamp()
    git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True)
    branch_groups = get_branches(git_repo)
    recent_prs = get_recent_prs()
    always_keep = get_branches_with_magic_label_or_open_pr()
    doomed: list[str] = []

    for base_branch, (date, sub_branches) in branch_groups.items():
        branch_age = current_time - date
        print(f"[{base_branch}] Updated {branch_age / SEC_IN_DAY} days ago")

        if base_branch in always_keep:
            print(f"[{base_branch}] Has magic label or open PR, skipping")
            continue

        pr = recent_prs.get(base_branch)
        if not pr:
            # No recent PR: only the branch's own age matters.
            if branch_age < NO_PR_RETENTION:
                continue
        else:
            pr_age = current_time - pr["updatedAt"]
            print(
                f"[{base_branch}] Has PR {pr['number']}: {pr['state']}, updated {pr_age / SEC_IN_DAY} days ago"
            )
            if pr_age < CLOSED_PR_RETENTION or branch_age < CLOSED_PR_RETENTION:
                continue

        print(f"[{base_branch}] Checking for branch protections")
        if any(map(is_protected, sub_branches)):
            print(f"[{base_branch}] Is protected")
            continue

        for sub_branch in sub_branches:
            print(f"[{base_branch}] Deleting {sub_branch}")
        doomed.extend(sub_branches)

        # The budget is only checked after a group has been queued for deletion.
        if ESTIMATED_TOKENS[0] > 400:
            print("Estimated tokens exceeded, exiting")
            break

    print(f"To delete ({len(doomed)}):")
    for branch in doomed:
        print(f"About to delete branch {branch}")
        delete_branch(git_repo, branch)
def delete_old_tags() -> None:
    """Delete ciflow and autorevert tags whose tagged commit is over 7 days old.

    Only tags named ``ciflow/<...>/<5-6 digit number or 40-hex sha>`` or
    ``trunk/<40-hex sha>`` are considered. Lightweight tags carry no creation
    date, so age is measured from the committer date of the commit the tag
    points to rather than from when the tag was pushed.

    Processing stops once ``ESTIMATED_TOKENS[0]`` exceeds 400 (each deletion
    adds one). A failure on an individual tag is printed and the loop moves on
    to the next tag.
    """
    git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True)

    def delete_tag(tag: str) -> None:
        print(f"Deleting tag {tag}")
        ESTIMATED_TOKENS[0] += 1
        delete_branch(git_repo, f"refs/tags/{tag}")

    tags = git_repo._run_git("tag").splitlines()

    CIFLOW_TAG_REGEX = re.compile(r"^ciflow\/.*\/(\d{5,6}|[0-9a-f]{40})$")
    AUTO_REVERT_TAG_REGEX = re.compile(r"^trunk\/[0-9a-f]{40}$")
    for tag in tags:
        try:
            if ESTIMATED_TOKENS[0] > 400:
                print("Estimated tokens exceeded, exiting")
                break
            if not CIFLOW_TAG_REGEX.match(tag) and not AUTO_REVERT_TAG_REGEX.match(tag):
                continue

            # Ask for the commit explicitly: `git show <tag>` on an annotated
            # tag prints the tag header before the formatted commit, which
            # would make the int() parse below fail. The refs/tags/ prefix
            # also avoids ambiguity with a same-named branch.
            tag_info = git_repo._run_git(
                "show", "-s", "--format=%ct", f"refs/tags/{tag}^{{commit}}"
            )
            # %ct is the committer date as a Unix timestamp, so it compares
            # directly against the current POSIX time.
            tag_timestamp = int(tag_info.strip())
            tag_age_days = (datetime.now().timestamp() - tag_timestamp) / SEC_IN_DAY

            if tag_age_days > 7:
                print(f"[{tag}] Tag is older than 7 days, deleting")
                delete_tag(tag)
        except Exception as e:
            print(f"Failed to check tag {tag}: {e}")
# Script entry point: clean up branches first, then tags. Both functions bump
# the shared ESTIMATED_TOKENS counter, so work done on branches counts toward
# the budget checked while deleting tags.
if __name__ == "__main__":
    delete_branches()
    delete_old_tags()