#!/usr/bin/env python3

# NB: the following functions are used in Meta-internal workflows
# (github_first_try_merge/my_handler.py) and thus have functionality limitations
# (no `git` command access, no network access besides the strict allow list):
#
# find_matching_merge_rule
# read_merge_rules
#
# Also any signature changes of these functions, as well as changes to the `GitHubPR`
# class, will likely require corresponding changes for the internal workflows.

import base64
import json
import os
import re
import time
import urllib.parse
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from functools import cache
from pathlib import Path
from re import Pattern
from typing import Any, Callable, cast, NamedTuple, Optional
from warnings import warn

import yaml
from github_utils import (
    gh_close_pr,
    gh_fetch_json_list,
    gh_fetch_merge_base,
    gh_fetch_url,
    gh_graphql,
    gh_post_commit_comment,
    gh_post_pr_comment,
    gh_update_pr_state,
    GitHubComment,
)
from gitutils import (
    are_ghstack_branches_in_sync,
    get_git_remote_name,
    get_git_repo_dir,
    GitRepo,
    patterns_to_regex,
    retries_decorator,
)
from label_utils import (
    gh_add_labels,
    gh_remove_label,
    has_required_labels,
    LABEL_ERR_MSG,
)
from trymerge_explainer import get_revert_message, TryMergeExplainer


# labels
MERGE_IN_PROGRESS_LABEL = "merging"
MERGE_COMPLETE_LABEL = "merged"


class JobCheckState(NamedTuple):
    name: str
    url: str
    status: Optional[str]
    classification: Optional[str]
    job_id: Optional[int]
    title: Optional[str]
    summary: Optional[str]


JobNameToStateDict = dict[str, JobCheckState]


class WorkflowCheckState:
    def __init__(self, name: str, url: str, run_id: int, status: Optional[str]):
        self.name: str = name
        self.url: str = url
        self.run_id: int = run_id
        self.status: Optional[str] = status
        self.jobs: JobNameToStateDict = {}


GH_PR_REVIEWS_FRAGMENT = """
fragment PRReviews on PullRequestReviewConnection {
  nodes {
    author {
      login
    }
    bodyText
    createdAt
    authorAssociation
    editor {
      login
    }
    databaseId
    url
    state
  }
  pageInfo {
    startCursor
    hasPreviousPage
  }
}
"""

GH_CHECKSUITES_FRAGMENT = """
fragment PRCheckSuites on CheckSuiteConnection {
  edges {
    node {
      workflowRun {
        workflow {
          name
          databaseId
        }
        databaseId
        url
      }
      checkRuns(first: 50) {
        nodes {
          name
          conclusion
          detailsUrl
          databaseId
          title
          summary
        }
        pageInfo {
          endCursor
          hasNextPage
        }
      }
      conclusion
    }
    cursor
  }
  pageInfo {
    hasNextPage
  }
}
"""

GH_COMMIT_AUTHORS_FRAGMENT = """
fragment CommitAuthors on PullRequestCommitConnection {
  nodes {
    commit {
      authors(first: 2) {
        nodes {
          user {
            login
          }
          email
          name
        }
      }
      oid
    }
  }
  pageInfo {
    endCursor
    hasNextPage
  }
}
"""

GH_GET_PR_INFO_QUERY = (
    GH_PR_REVIEWS_FRAGMENT
    + GH_CHECKSUITES_FRAGMENT
    + GH_COMMIT_AUTHORS_FRAGMENT
    + """
query ($owner: String!, $name: String!, $number: Int!) {
  repository(owner: $owner, name: $name) {
    pullRequest(number: $number) {
      closed
      isCrossRepository
      author {
        login
      }
      title
      body
      headRefName
      headRepository {
        nameWithOwner
      }
      baseRefName
      baseRefOid
      baseRepository {
        nameWithOwner
        isPrivate
        defaultBranchRef {
          name
        }
      }
      mergeCommit {
        oid
      }
      commits_with_authors: commits(first: 100) {
        ...CommitAuthors
        totalCount
      }
      commits(last: 1) {
        nodes {
          commit {
            checkSuites(first: 10) {
              ...PRCheckSuites
            }
            status {
              contexts {
                context
                state
                targetUrl
              }
            }
            oid
          }
        }
      }
      changedFiles
      files(first: 100) {
        nodes {
          path
        }
        pageInfo {
          endCursor
          hasNextPage
        }
      }
      reviews(last: 100) {
        ...PRReviews
      }
      comments(last: 5) {
        nodes {
          bodyText
          createdAt
          author {
            login
          }
          authorAssociation
          editor {
            login
          }
          databaseId
          url
        }
        pageInfo {
          startCursor
          hasPreviousPage
        }
      }
      labels(first: 100) {
        edges {
          node {
            name
          }
        }
      }
    }
  }
}
"""
)
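
# Illustrative note (not part of the original script): the fragments above are
# composed by plain string concatenation, so GH_GET_PR_INFO_QUERY is just the
# fragment definitions followed by the query text that spreads them, e.g.
#   rc = gh_graphql(GH_GET_PR_INFO_QUERY, owner="pytorch", name="pytorch", number=1)
#   title = rc["data"]["repository"]["pullRequest"]["title"]
# The keyword arguments to gh_graphql map 1:1 to the declared GraphQL
# variables ($owner, $name, $number).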

GH_GET_PR_NEXT_FILES_QUERY = """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      files(first: 100, after: $cursor) {
        nodes {
          path
        }
        pageInfo {
          endCursor
          hasNextPage
        }
      }
    }
  }
}
"""

GH_GET_PR_NEXT_CHECKSUITES = (
    GH_CHECKSUITES_FRAGMENT
    + """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      commits(last: 1) {
        nodes {
          commit {
            oid
            checkSuites(first: 10, after: $cursor) {
              ...PRCheckSuites
            }
          }
        }
      }
    }
  }
}
"""
)

GH_GET_PR_NEXT_CHECK_RUNS = """
query ($owner: String!, $name: String!, $number: Int!, $cs_cursor: String, $cr_cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      commits(last: 1) {
        nodes {
          commit {
            oid
            checkSuites(first: 1, after: $cs_cursor) {
              nodes {
                checkRuns(first: 100, after: $cr_cursor) {
                  nodes {
                    name
                    conclusion
                    detailsUrl
                    databaseId
                    title
                    summary
                  }
                  pageInfo {
                    endCursor
                    hasNextPage
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
"""

GH_GET_PR_PREV_COMMENTS = """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      comments(last: 100, before: $cursor) {
        nodes {
          bodyText
          createdAt
          author {
            login
          }
          authorAssociation
          editor {
            login
          }
          databaseId
          url
        }
        pageInfo {
          startCursor
          hasPreviousPage
        }
      }
    }
  }
}
"""

# This query needs read-org permission
GH_GET_TEAM_MEMBERS_QUERY = """
query($org: String!, $name: String!, $cursor: String) {
  organization(login: $org) {
    team(slug: $name) {
      members(first: 100, after: $cursor) {
        nodes {
          login
        }
        pageInfo {
          hasNextPage
          endCursor
        }
      }
    }
  }
}
"""

GH_GET_PR_NEXT_AUTHORS_QUERY = (
    GH_COMMIT_AUTHORS_FRAGMENT
    + """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      commits_with_authors: commits(first: 100, after: $cursor) {
        ...CommitAuthors
      }
    }
  }
}
"""
)

GH_GET_PR_PREV_REVIEWS_QUERY = (
    GH_PR_REVIEWS_FRAGMENT
    + """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      reviews(last: 100, before: $cursor) {
        ...PRReviews
      }
    }
  }
}
"""
)

GH_GET_REPO_SUBMODULES = """
query ($owner: String!, $name: String!) {
  repository(owner: $owner, name: $name) {
    submodules(first: 100) {
      nodes {
        path
      }
      pageInfo {
        endCursor
        hasNextPage
      }
    }
  }
}
"""
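
# Illustrative note (not part of the original script): all of the *_NEXT_* and
# *_PREV_* queries above follow the same cursor-based pagination pattern. The
# first page comes back embedded in GH_GET_PR_INFO_QUERY, and callers keep
# re-issuing the follow-up query with pageInfo.endCursor (forward) or
# pageInfo.startCursor (backward) until hasNextPage/hasPreviousPage is false.
# A minimal sketch of the forward direction:
#
#   info = gh_get_pr_info(org, project, pr_num)["files"]
#   while info["pageInfo"]["hasNextPage"]:
#       rc = gh_graphql(GH_GET_PR_NEXT_FILES_QUERY, name=project, owner=org,
#                       number=pr_num, cursor=info["pageInfo"]["endCursor"])
#       info = rc["data"]["repository"]["pullRequest"]["files"]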

RE_GHSTACK_HEAD_REF = re.compile(r"^(gh/[^/]+/[0-9]+/)head$")
RE_GHSTACK_DESC = re.compile(r"Stack.*:\r?\n(\* [^\r\n]+\r?\n)+", re.MULTILINE)
RE_PULL_REQUEST_RESOLVED = re.compile(
    r"(Pull Request resolved|Pull-Request-resolved|Pull-Request): "
    r"https://github.com/(?P<owner>[^/]+)/(?P<repo>[^/]+)/pull/(?P<number>[0-9]+)",
    re.MULTILINE,
)
RE_PR_CC_LINE = re.compile(r"^cc:? @\w+.*\r?\n?$", re.MULTILINE)
RE_DIFF_REV = re.compile(r"^Differential Revision:.+?(D[0-9]+)", re.MULTILINE)
CIFLOW_LABEL = re.compile(r"^ciflow/.+")
CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk")
MERGE_RULE_PATH = Path(".github") / "merge_rules.yaml"
REMOTE_MAIN_BRANCH = "origin/main"
DRCI_CHECKRUN_NAME = "Dr.CI"
INTERNAL_CHANGES_CHECKRUN_NAME = "Meta Internal-Only Changes Check"
HAS_NO_CONNECTED_DIFF_TITLE = (
    "There is no internal Diff connected, this can be merged now"
)
# This could be set to -1 to ignore all flaky and broken trunk failures. On the
# other hand, using a large value like 10 here might be useful in a sev situation
IGNORABLE_FAILED_CHECKS_THESHOLD = 10
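
# Illustrative examples (not part of the original script) of what these
# patterns match:
#   RE_GHSTACK_HEAD_REF.match("gh/someuser/123/head")        -> matches
#   RE_PULL_REQUEST_RESOLVED.search(
#       "Pull Request resolved: https://github.com/pytorch/pytorch/pull/161560"
#   ).group("number")                                         -> "161560"
#   RE_DIFF_REV.search("Differential Revision: D12345678")    -> captures "D12345678"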


def iter_issue_timeline_until_comment(
    org: str, repo: str, issue_number: int, target_comment_id: int, max_pages: int = 200
) -> Any:
    """
    Yield timeline entries in order until (and including) the 'commented' entry
    whose id == target_comment_id.  Stops once the target comment is encountered.
    """
    page = 1

    while page <= max_pages:
        url = (
            f"https://api.github.com/repos/{org}/{repo}/issues/{issue_number}/timeline"
        )
        params = {"per_page": 100, "page": page}

        batch = gh_fetch_json_list(url, params)

        if not batch:
            return
        for ev in batch:
            # The target is the issue comment row with event == "commented" and id == issue_comment_id
            if ev.get("event") == "commented" and ev.get("id") == target_comment_id:
                yield ev  # nothing in the timeline after this matters, so stop early
                return
            yield ev
        if len(batch) < 100:
            return
        page += 1

    # If we got here without finding the comment, then we either hit a bug or some GitHub PR
    # has a _really_ long timeline.
    # The max # of pages found on any pytorch/pytorch PR at the time of this change was 41
    raise RuntimeError(
        f"Could not find comment {target_comment_id} in the first {max_pages} pages "
        f"of the timeline at url {url}. "
        "This is most likely a bug, please report it to the @pytorch/pytorch-dev-infra team."
    )
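# Illustrative usage sketch (not part of the original script): callers consume
# the generator and track the PR head as they go, e.g.
#
#   for ev in iter_issue_timeline_until_comment("pytorch", "pytorch", 161560, 999):
#       if ev.get("event") == "committed":
#           head = ev.get("sha")
#
# (999 is a made-up comment id; see GitHubPR.get_commit_sha_at_comment below
# for the real consumer.)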

def sha_from_committed_event(ev: dict[str, Any]) -> Optional[str]:
    """Extract SHA from committed event in timeline"""
    return ev.get("sha")


def sha_from_force_push_after(ev: dict[str, Any]) -> Optional[str]:
    """Extract SHA from force push event in timeline"""
    # The current GitHub API format
    commit_id = ev.get("commit_id")
    if commit_id:
        return str(commit_id)

    # Legacy format
    after = ev.get("after") or ev.get("after_commit") or {}
    if isinstance(after, dict):
        return after.get("sha") or after.get("oid")
    return ev.get("after_sha") or ev.get("head_sha")
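# Illustrative event shapes (assumed, not exhaustive) that the two helpers
# above understand:
#   {"event": "committed", "sha": "abc123..."}                         # committed
#   {"event": "head_ref_force_pushed", "commit_id": "def456..."}       # current API
#   {"event": "head_ref_force_pushed", "after": {"sha": "def456..."}}  # legacy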

def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
    rc = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no)
    return rc["data"]["repository"]["pullRequest"]


@cache
def gh_get_team_members(org: str, name: str) -> list[str]:
    rc: list[str] = []
    team_members: dict[str, Any] = {
        "pageInfo": {"hasNextPage": "true", "endCursor": None}
    }
    while bool(team_members["pageInfo"]["hasNextPage"]):
        query = gh_graphql(
            GH_GET_TEAM_MEMBERS_QUERY,
            org=org,
            name=name,
            cursor=team_members["pageInfo"]["endCursor"],
        )
        team = query["data"]["organization"]["team"]
        if team is None:
            warn(f"Requested non-existing team {org}/{name}")
            return []
        team_members = team["members"]
        rc += [member["login"] for member in team_members["nodes"]]
    return rc


def get_check_run_name_prefix(workflow_run: Any) -> str:
    if workflow_run is None:
        return ""
    else:
        return f"{workflow_run['workflow']['name']} / "


def is_passing_status(status: Optional[str]) -> bool:
    return status is not None and status.upper() in ["SUCCESS", "SKIPPED", "NEUTRAL"]
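# Illustrative note (not part of the original script): the prefix keeps check
# run names unique across workflows, mirroring the GitHub UI. For example, a
# "build" check run inside a workflow named "pull" would be recorded as
# "pull / build", while a check run with no owning workflow keeps its bare name.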

def add_workflow_conclusions(
    checksuites: Any,
    get_next_checkruns_page: Callable[[list[dict[str, dict[str, Any]]], int, Any], Any],
    get_next_checksuites: Callable[[Any], Any],
) -> JobNameToStateDict:
    # graphql seems to favor the most recent workflow run, so in theory we
    # shouldn't need to account for reruns, but do it just in case

    # workflow id -> workflow state (which in turn holds job -> job info)
    workflows: dict[int, WorkflowCheckState] = {}

    # for the jobs that don't have a workflow
    no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None)

    def add_conclusions(edges: Any) -> None:
        for edge_idx, edge in enumerate(edges):
            node = edge["node"]
            workflow_run = node["workflowRun"]
            checkruns = node["checkRuns"]

            workflow_obj: WorkflowCheckState = no_workflow_obj

            if workflow_run is not None:
                # This is the usual workflow run ID we see on GitHub
                workflow_run_id = workflow_run["databaseId"]
                # While this is the metadata name and ID of the workflow itself
                workflow_name = workflow_run["workflow"]["name"]
                workflow_id = workflow_run["workflow"]["databaseId"]

                workflow_conclusion = node["conclusion"]
                # Do not override an existing status with cancelled
                # (the dict is keyed by workflow id)
                if workflow_conclusion == "CANCELLED" and workflow_id in workflows:
                    continue

                # Only keep the latest workflow run for each workflow; heuristically,
                # it's the run with the largest run ID
                if (
                    workflow_id not in workflows
                    or workflows[workflow_id].run_id < workflow_run_id
                ):
                    workflows[workflow_id] = WorkflowCheckState(
                        name=workflow_name,
                        status=workflow_conclusion,
                        url=workflow_run["url"],
                        run_id=workflow_run_id,
                    )
                workflow_obj = workflows[workflow_id]

            while checkruns is not None:
                for checkrun_node in checkruns["nodes"]:
                    if not isinstance(checkrun_node, dict):
                        warn(f"Expected dictionary, but got {type(checkrun_node)}")
                        continue
                    checkrun_name = f"{get_check_run_name_prefix(workflow_run)}{checkrun_node['name']}"
                    existing_checkrun = workflow_obj.jobs.get(checkrun_name)
                    if existing_checkrun is None or not is_passing_status(
                        existing_checkrun.status
                    ):
                        workflow_obj.jobs[checkrun_name] = JobCheckState(
                            checkrun_name,
                            checkrun_node["detailsUrl"],
                            checkrun_node["conclusion"],
                            classification=None,
                            job_id=checkrun_node["databaseId"],
                            title=checkrun_node["title"],
                            summary=checkrun_node["summary"],
                        )

                if bool(checkruns["pageInfo"]["hasNextPage"]):
                    checkruns = get_next_checkruns_page(edges, edge_idx, checkruns)
                else:
                    checkruns = None

    all_edges = checksuites["edges"].copy()
    while bool(checksuites["pageInfo"]["hasNextPage"]):
        checksuites = get_next_checksuites(checksuites)
        all_edges.extend(checksuites["edges"])

    add_conclusions(all_edges)

    # Flatten the dictionaries.  If there are jobs in the workflow run, put
    # the jobs in but don't put the workflow in.  We care more about the jobs in
    # the workflow that ran than the container workflow.
    res: JobNameToStateDict = {}
    for workflow in workflows.values():
        if len(workflow.jobs) > 0:
            for job_name, job in workflow.jobs.items():
                res[job_name] = job
        else:
            res[workflow.name] = JobCheckState(
                workflow.name,
                workflow.url,
                workflow.status,
                classification=None,
                job_id=None,
                title=None,
                summary=None,
            )
    for job_name, job in no_workflow_obj.jobs.items():
        res[job_name] = job
    return res
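# Illustrative result shape (not part of the original script): for a workflow
# "pull" with jobs "build" and "test", add_workflow_conclusions returns
# something like
#   {"pull / build": JobCheckState(name="pull / build", ..., status="SUCCESS"),
#    "pull / test":  JobCheckState(name="pull / test", ..., status="FAILURE")}
# i.e. one flat entry per job, plus one entry for each job-less workflow.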

def parse_args() -> Any:
    from argparse import ArgumentParser

    parser = ArgumentParser("Merge PR into default branch")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--revert", action="store_true")
    parser.add_argument("--force", action="store_true")
    parser.add_argument("--ignore-current", action="store_true")
    parser.add_argument("--check-mergeability", action="store_true")
    parser.add_argument("--comment-id", type=int)
    parser.add_argument("--reason", type=str)
    parser.add_argument("pr_num", type=int)
    return parser.parse_args()
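# Illustrative invocations (not part of the original script; the PR and
# comment ids are made up):
#   trymerge.py --comment-id 123456789 161560          # merge PR 161560
#   trymerge.py --revert --reason "broke CI" 161560    # revert it
#   trymerge.py --dry-run 161560                       # go through the motions only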

def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -> bool:
    if comment_id is None:
        return False
    comment = pr.get_comment_by_id(comment_id)
    if comment.editor_login is not None:
        return False
    return comment.author_login == "facebook-github-bot"


def _revlist_to_prs(
    repo: GitRepo,
    pr: "GitHubPR",
    rev_list: Iterable[str],
    should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
) -> list[tuple["GitHubPR", str]]:
    rc: list[tuple[GitHubPR, str]] = []
    for idx, rev in enumerate(rev_list):
        msg = repo.commit_message(rev)
        # findall doesn't return named captures, so we need to use finditer
        all_matches = list(RE_PULL_REQUEST_RESOLVED.finditer(msg))
        if len(all_matches) != 1:
            raise RuntimeError(
                f"Found an unexpected number of PRs mentioned in commit {rev}: "
                f"{len(all_matches)}.  This is probably because you are using an "
                "old version of ghstack.  Please update ghstack and resubmit "
                "your PRs"
            )

        m = all_matches[0]
        if pr.org != m.group("owner") or pr.project != m.group("repo"):
            raise RuntimeError(
                f"PR {m.group('number')} resolved to wrong owner/repo pair"
            )
        pr_num = int(m.group("number"))
        candidate = GitHubPR(pr.org, pr.project, pr_num) if pr_num != pr.pr_num else pr
        if should_skip is not None and should_skip(idx, candidate):
            continue
        rc.append((candidate, rev))
    return rc


def get_ghstack_prs(
    repo: GitRepo, pr: "GitHubPR", open_only: bool = True
) -> list[tuple["GitHubPR", str]]:
    """
    Get the PRs in the stack that are below this PR (inclusive).  Throws an error if any of the open PRs are out of sync.
    :param open_only: Only return open PRs
    """
    # For ghstack, cherry-pick commits based on the orig branch at origin
    orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
    rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")

    def skip_func(idx: int, candidate: "GitHubPR") -> bool:
        if not open_only or not candidate.is_closed():
            return False
        print(
            f"Skipping {idx + 1} of {len(rev_list)} PR (#{candidate.pr_num}) as it's already been merged"
        )
        return True

    assert pr.is_ghstack_pr()
    entire_stack = _revlist_to_prs(repo, pr, reversed(rev_list), skip_func)
    print(
        f"Found {len(entire_stack)} PRs in the stack for {pr.pr_num}: {[x[0].pr_num for x in entire_stack]}"
    )

    for stacked_pr, rev in entire_stack:
        if stacked_pr.is_closed():
            continue
        base_ref = stacked_pr.base_ref()
        if base_ref == pr.default_branch():
            base_ref = repo.get_merge_base(
                f"{repo.remote}/{base_ref}", f"{repo.remote}/{stacked_pr.head_ref()}"
            )
        if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref(), base_ref):
            raise RuntimeError(
                f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on "
                + f"branch {stacked_pr.get_ghstack_orig_ref()} that would be merged into {stacked_pr.default_branch()}.  "
                + "This usually happens because there is a non-ghstack change in the PR.  "
                + f"Please sync them and try again (ex. make the changes on {orig_ref} and run ghstack)."
            )
    return entire_stack
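# Illustrative note (not part of the original script): ghstack maintains three
# refs per stacked PR, e.g. for stack entry 123 of user "someuser":
#   gh/someuser/123/head  - the PR's head branch on GitHub
#   gh/someuser/123/base  - the synthetic base branch the PR is opened against
#   gh/someuser/123/orig  - the original commits, parented on trunk, which is
#                           what get_ghstack_prs walks with revlist above.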

class GitHubPR:
    def __init__(self, org: str, project: str, pr_num: int) -> None:
        assert isinstance(pr_num, int)
        self.org = org
        self.project = project
        self.pr_num = pr_num
        self.info = gh_get_pr_info(org, project, pr_num)
        self.changed_files: Optional[list[str]] = None
        self.labels: Optional[list[str]] = None
        self.conclusions: Optional[JobNameToStateDict] = None
        self.comments: Optional[list[GitHubComment]] = None
        self._authors: Optional[list[tuple[str, str]]] = None
        self._reviews: Optional[list[tuple[str, str]]] = None
        self.merge_base: Optional[str] = None
        self.submodules: Optional[list[str]] = None

    def is_closed(self) -> bool:
        return bool(self.info["closed"])

    def is_cross_repo(self) -> bool:
        return bool(self.info["isCrossRepository"])

    def base_ref(self) -> str:
        return cast(str, self.info["baseRefName"])

    def default_branch(self) -> str:
        return cast(str, self.info["baseRepository"]["defaultBranchRef"]["name"])

    def head_ref(self) -> str:
        return cast(str, self.info["headRefName"])

    def is_ghstack_pr(self) -> bool:
        return RE_GHSTACK_HEAD_REF.match(self.head_ref()) is not None

    def get_ghstack_orig_ref(self) -> str:
        assert self.is_ghstack_pr()
        return re.sub(r"/head$", "/orig", self.head_ref())

    def is_base_repo_private(self) -> bool:
        return bool(self.info["baseRepository"]["isPrivate"])

    def get_changed_files_count(self) -> int:
        return int(self.info["changedFiles"])

    def last_commit(self) -> Any:
        return self.info["commits"]["nodes"][-1]["commit"]

    def last_commit_sha(self, default: Optional[str] = None) -> str:
        # for commits, the oid is the sha

        if default is None:
            return str(self.last_commit()["oid"])

        return str(self.last_commit().get("oid", default))

    def get_merge_base(self) -> str:
        if self.merge_base:
            return self.merge_base

        last_commit_sha = self.last_commit_sha()
        # NB: We could use self.base_ref() here for a regular PR; however, that doesn't
        # work for ghstack, where the base is a custom branch, i.e. gh/USER/ID/base,
        # so let's just use main instead
        self.merge_base = gh_fetch_merge_base(
            self.org, self.project, last_commit_sha, self.default_branch()
        )

        # Fall back to baseRefOid if the API call fails, i.e. on rate limit. Note that baseRefOid
        # points to the base ref associated with the PR or, in other words, the head of main
        # when the PR was created or rebased. This is not necessarily the merge base commit,
        # but it can serve as a fallback in most cases and it's readily available as part
        # of the PR info
        if not self.merge_base:
            self.merge_base = cast(str, self.info["baseRefOid"])

        return self.merge_base
    def get_changed_files(self) -> list[str]:
        if self.changed_files is None:
            info = self.info
            unique_changed_files = set()
            # Do not try to fetch more than 10K files
            for _ in range(100):
                unique_changed_files.update([x["path"] for x in info["files"]["nodes"]])
                if not info["files"]["pageInfo"]["hasNextPage"]:
                    break
                rc = gh_graphql(
                    GH_GET_PR_NEXT_FILES_QUERY,
                    name=self.project,
                    owner=self.org,
                    number=self.pr_num,
                    cursor=info["files"]["pageInfo"]["endCursor"],
                )
                info = rc["data"]["repository"]["pullRequest"]
            self.changed_files = list(unique_changed_files)

        if len(self.changed_files) != self.get_changed_files_count():
            raise RuntimeError("Changed file count mismatch")
        return self.changed_files

    def get_submodules(self) -> list[str]:
        if self.submodules is None:
            rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org)
            info = rc["data"]["repository"]["submodules"]
            self.submodules = [s["path"] for s in info["nodes"]]
        return self.submodules

    def get_changed_submodules(self) -> list[str]:
        submodules = self.get_submodules()
        return [f for f in self.get_changed_files() if f in submodules]

    def has_invalid_submodule_updates(self) -> bool:
        """Submodule updates in a PR are invalid if the submodule keyword
        is not mentioned in either the title, the body/description,
        or any of the labels.
        """
        return (
            len(self.get_changed_submodules()) > 0
            and "submodule" not in self.get_title().lower()
            and "submodule" not in self.get_body().lower()
            and all("submodule" not in label for label in self.get_labels())
        )
    def _get_reviews(self) -> list[tuple[str, str]]:
        if self._reviews is None:
            self._reviews = []
            info = self.info
            for _ in range(100):
                nodes = info["reviews"]["nodes"]
                self._reviews = [
                    (node["author"]["login"], node["state"]) for node in nodes
                ] + self._reviews
                if not info["reviews"]["pageInfo"]["hasPreviousPage"]:
                    break
                rc = gh_graphql(
                    GH_GET_PR_PREV_REVIEWS_QUERY,
                    name=self.project,
                    owner=self.org,
                    number=self.pr_num,
                    cursor=info["reviews"]["pageInfo"]["startCursor"],
                )
                info = rc["data"]["repository"]["pullRequest"]
        reviews = {
            author: state for author, state in self._reviews if state != "COMMENTED"
        }
        return list(reviews.items())

    def get_approved_by(self) -> list[str]:
        return [login for (login, state) in self._get_reviews() if state == "APPROVED"]

    def get_commit_count(self) -> int:
        return int(self.info["commits_with_authors"]["totalCount"])

    def get_commit_sha_at_comment(self, comment_id: int) -> Optional[str]:
        """
        Get the PR head commit SHA that was present when a specific comment was posted.
        This ensures we only merge the state of the PR at the time the merge command was issued,
        not any subsequent commits that may have been pushed after.

        Returns None if no head-changing events were found before the comment or if the comment was not found.
        """
        head = None

        try:
            for event in iter_issue_timeline_until_comment(
                self.org, self.project, self.pr_num, comment_id
            ):
                etype = event.get("event")
                if etype == "committed":
                    sha = sha_from_committed_event(event)
                    if sha:
                        head = sha
                        print(f"Timeline: Found commit event for SHA {sha}")
                elif etype == "head_ref_force_pushed":
                    sha = sha_from_force_push_after(event)
                    if sha:
                        head = sha
                        print(f"Timeline: Found force push event for SHA {sha}")
                elif etype == "commented":
                    if event.get("id") == comment_id:
                        print(f"Timeline: Found final comment; head was {head}")
                        return head
        except Exception as e:
            print(
                f"Warning: Failed to reconstruct timeline for comment {comment_id}: {e}"
            )
            return None

        print(f"Did not find comment with id {comment_id} in the PR timeline")
        return None
    def get_pr_creator_login(self) -> str:
        return cast(str, self.info["author"]["login"])

    def _fetch_authors(self) -> list[tuple[str, str]]:
        if self._authors is not None:
            return self._authors
        authors: list[tuple[str, str]] = []

        def add_authors(info: dict[str, Any]) -> None:
            for node in info["commits_with_authors"]["nodes"]:
                for author_node in node["commit"]["authors"]["nodes"]:
                    user_node = author_node["user"]
                    author = f"{author_node['name']} <{author_node['email']}>"
                    if user_node is None:
                        # If the author is not a GitHub user, the user node will be null
                        authors.append(("", author))
                    else:
                        authors.append((cast(str, user_node["login"]), author))

        info = self.info
        for _ in range(100):
            add_authors(info)
            if not info["commits_with_authors"]["pageInfo"]["hasNextPage"]:
                break
            rc = gh_graphql(
                GH_GET_PR_NEXT_AUTHORS_QUERY,
                name=self.project,
                owner=self.org,
                number=self.pr_num,
                cursor=info["commits_with_authors"]["pageInfo"]["endCursor"],
            )
            info = rc["data"]["repository"]["pullRequest"]
        self._authors = authors
        return authors

    def get_committer_login(self, num: int = 0) -> str:
        return self._fetch_authors()[num][0]

    def get_committer_author(self, num: int = 0) -> str:
        return self._fetch_authors()[num][1]

    def get_labels(self) -> list[str]:
        if self.labels is not None:
            return self.labels
        labels = (
            [node["node"]["name"] for node in self.info["labels"]["edges"]]
            if "labels" in self.info
            else []
        )
        self.labels = labels
        return self.labels
    def get_checkrun_conclusions(self) -> JobNameToStateDict:
        """Returns dict of checkrun -> [conclusion, url]"""
        if self.conclusions is not None:
            return self.conclusions
        orig_last_commit = self.last_commit()

        def get_pr_next_check_runs(
            edges: list[dict[str, dict[str, Any]]], edge_idx: int, checkruns: Any
        ) -> Any:
            rc = gh_graphql(
                GH_GET_PR_NEXT_CHECK_RUNS,
                name=self.project,
                owner=self.org,
                number=self.pr_num,
                cs_cursor=edges[edge_idx - 1]["cursor"] if edge_idx > 0 else None,
                cr_cursor=checkruns["pageInfo"]["endCursor"],
            )
            last_commit = rc["data"]["repository"]["pullRequest"]["commits"]["nodes"][
                -1
            ]["commit"]
            checkruns = last_commit["checkSuites"]["nodes"][-1]["checkRuns"]
            return checkruns

        def get_pr_next_checksuites(checksuites: Any) -> Any:
            rc = gh_graphql(
                GH_GET_PR_NEXT_CHECKSUITES,
                name=self.project,
                owner=self.org,
                number=self.pr_num,
                cursor=checksuites["edges"][-1]["cursor"],
            )
            info = rc["data"]["repository"]["pullRequest"]
            last_commit = info["commits"]["nodes"][-1]["commit"]
            if last_commit["oid"] != orig_last_commit["oid"]:
                raise RuntimeError("Last commit changed on PR")
            return last_commit["checkSuites"]

        checksuites = orig_last_commit["checkSuites"]

        self.conclusions = add_workflow_conclusions(
            checksuites, get_pr_next_check_runs, get_pr_next_checksuites
        )

        # Append old-style statuses (like the ones populated by CircleCI or EasyCLA) to conclusions
        if orig_last_commit["status"] and orig_last_commit["status"]["contexts"]:
            for status in orig_last_commit["status"]["contexts"]:
                name = status["context"]
                self.conclusions[name] = JobCheckState(
                    name,
                    status["targetUrl"],
                    status["state"],
                    classification=None,
                    job_id=None,
                    title=None,
                    summary=None,
                )

        # Making an exception for Apply lint suggestions/autoformat because the
        # bot adds a merged label -> triggers workflow -> sometimes needs
        # approval -> is read as failure, which results in a blocked merge, but
        # this workflow doesn't provide mergeability info
        self.conclusions.pop("Apply lint suggestions", None)

        return self.conclusions
    def get_authors(self) -> dict[str, str]:
        rc = {}
        for idx in range(len(self._fetch_authors())):
            rc[self.get_committer_login(idx)] = self.get_committer_author(idx)

        return rc

    def get_author(self) -> str:
        authors = self.get_authors()
        if len(authors) == 1:
            return next(iter(authors.values()))
        creator = self.get_pr_creator_login()
        # If the PR creator is not among the authors,
        # assume the PR was authored by the first commit author
        if creator not in authors:
            return self.get_committer_author(0)
        return authors[creator]

    def get_title(self) -> str:
        return cast(str, self.info["title"])

    def get_body(self) -> str:
        return cast(str, self.info["body"])

    def get_merge_commit(self) -> Optional[str]:
        mc = self.info["mergeCommit"]
        return mc["oid"] if mc is not None else None

    def get_pr_url(self) -> str:
        return f"https://github.com/{self.org}/{self.project}/pull/{self.pr_num}"

    @staticmethod
    def _comment_from_node(node: Any) -> GitHubComment:
        editor = node["editor"]
        return GitHubComment(
            body_text=node["bodyText"],
            created_at=node["createdAt"] if "createdAt" in node else "",
            author_login=node["author"]["login"],
            author_association=node["authorAssociation"],
            editor_login=editor["login"] if editor else None,
            database_id=node["databaseId"],
            url=node["url"],
        )

    def get_comments(self) -> list[GitHubComment]:
        if self.comments is not None:
            return self.comments
        self.comments = []
        info = self.info["comments"]
        # Do not try to fetch more than 10K comments
        for _ in range(100):
            self.comments = [
                self._comment_from_node(node) for node in info["nodes"]
            ] + self.comments
            if not info["pageInfo"]["hasPreviousPage"]:
                break
            rc = gh_graphql(
                GH_GET_PR_PREV_COMMENTS,
                name=self.project,
                owner=self.org,
                number=self.pr_num,
                cursor=info["pageInfo"]["startCursor"],
            )
            info = rc["data"]["repository"]["pullRequest"]["comments"]
        return self.comments

    def get_last_comment(self) -> GitHubComment:
        return self._comment_from_node(self.info["comments"]["nodes"][-1])

    def get_comment_by_id(self, database_id: int) -> GitHubComment:
        if self.comments is None:
            # Fast path - try searching the partially prefetched comments first
            for node in self.info["comments"]["nodes"]:
                comment = self._comment_from_node(node)
                if comment.database_id == database_id:
                    return comment

        for comment in self.get_comments():
            if comment.database_id == database_id:
                return comment

        # The comment could have actually been a review left on the PR (the message written alongside the review).
        # (This is generally done to trigger the merge right when a comment is left.)
        # Check those review comments to see if one of them was the comment in question.
        for node in self.info["reviews"]["nodes"]:
            # These review comments contain all the fields regular comments need
            comment = self._comment_from_node(node)
            if comment.database_id == database_id:
                return comment

        raise RuntimeError(f"Comment with id {database_id} not found")
    def get_diff_revision(self) -> Optional[str]:
        rc = RE_DIFF_REV.search(self.get_body())
        return rc.group(1) if rc is not None else None

    def has_internal_changes(self) -> bool:
        checkrun_name = INTERNAL_CHANGES_CHECKRUN_NAME
        if self.get_diff_revision() is None:
            return False
        checks = self.get_checkrun_conclusions()
        if checks is None or checkrun_name not in checks:
            return False
        return checks[checkrun_name].status != "SUCCESS"

    def has_no_connected_diff(self) -> bool:
        checkrun_name = INTERNAL_CHANGES_CHECKRUN_NAME
        checks = self.get_checkrun_conclusions()
        if checks is None or checkrun_name not in checks:
            return False
        return checks[checkrun_name].title == HAS_NO_CONNECTED_DIFF_TITLE
    def merge_ghstack_into(
        self,
        repo: GitRepo,
        skip_mandatory_checks: bool,
        comment_id: Optional[int] = None,
        skip_all_rule_checks: bool = False,
    ) -> list["GitHubPR"]:
        assert self.is_ghstack_pr()
        ghstack_prs = get_ghstack_prs(
            repo, self, open_only=False
        )  # raises error if out of sync
        pr_dependencies = []
        for pr, rev in ghstack_prs:
            if pr.is_closed():
                pr_dependencies.append(pr)
                continue

            commit_msg = pr.gen_commit_message(
                filter_ghstack=True, ghstack_deps=pr_dependencies
            )
            if pr.pr_num != self.pr_num and not skip_all_rule_checks:
                # Raises exception if matching rule is not found
                find_matching_merge_rule(
                    pr,
                    repo,
                    skip_mandatory_checks=skip_mandatory_checks,
                    skip_internal_checks=can_skip_internal_checks(self, comment_id),
                )
            repo.cherry_pick(rev)
            repo.amend_commit_message(commit_msg)
            pr_dependencies.append(pr)
        return [x for x, _ in ghstack_prs if not x.is_closed()]
    def gen_commit_message(
        self,
        filter_ghstack: bool = False,
        ghstack_deps: Optional[list["GitHubPR"]] = None,
    ) -> str:
        """Fetches the title and body from the PR description,
        adds the "Approved by" and "Pull Request resolved" trailers,
        and optionally filters out the ghstack info"""
        # Adding the url here makes it clickable within the Github UI
        approved_by_urls = ", ".join(
            prefix_with_github_url(login) for login in self.get_approved_by()
        )
        # Remove "cc: " line from the message body
        msg_body = re.sub(RE_PR_CC_LINE, "", self.get_body())
        if filter_ghstack:
            msg_body = re.sub(RE_GHSTACK_DESC, "", msg_body)
        msg = self.get_title() + f" (#{self.pr_num})\n\n"
        msg += msg_body

        msg += f"\nPull Request resolved: {self.get_pr_url()}\n"
        msg += f"Approved by: {approved_by_urls}\n"
        if ghstack_deps:
            msg += f"ghstack dependencies: {', '.join([f'#{pr.pr_num}' for pr in ghstack_deps])}\n"

        # Mention PR co-authors, which should be at the end of the message
        # and separated from the body by two newlines
        first_coauthor = True
        for author_login, author_name in self.get_authors().items():
            if author_login != self.get_pr_creator_login():
                if first_coauthor:
                    msg, first_coauthor = (msg + "\n", False)
                msg += f"\nCo-authored-by: {author_name}"

        return msg
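    # Illustrative output of gen_commit_message (not part of the original
    # script; the title, reviewer, and numbers are made up):
    #
    #   Fix the frobnicator (#161560)
    #
    #   <PR body with "cc @..." lines and the ghstack stack list stripped>
    #
    #   Pull Request resolved: https://github.com/pytorch/pytorch/pull/161560
    #   Approved by: https://github.com/somereviewer
    #   ghstack dependencies: #161558
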
    def add_numbered_label(self, label_base: str, dry_run: bool) -> None:
        labels = self.get_labels() if self.labels is not None else []
        full_label = label_base
        count = 0
        for label in labels:
            if label_base in label:
                count += 1
                full_label = f"{label_base}X{count}"
        self.add_label(full_label, dry_run)
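    # Illustrative behavior sketch of add_numbered_label (not part of the
    # original script): with label_base="merged", the first call adds "merged";
    # if "merged" is already present, the next call adds "mergedX1", then
    # "mergedX2", and so on, since each existing label containing the base
    # bumps the counter.
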
    def add_label(self, label: str, dry_run: bool) -> None:
        gh_add_labels(self.org, self.project, self.pr_num, [label], dry_run)

    def merge_into(
        self,
        repo: GitRepo,
        *,
        skip_mandatory_checks: bool = False,
        dry_run: bool = False,
        comment_id: int,
        ignore_current_checks: Optional[list[str]] = None,
    ) -> None:
        # Raises exception if matching rule is not found
        (
            merge_rule,
            pending_checks,
            failed_checks,
            ignorable_checks,
        ) = find_matching_merge_rule(
            self,
            repo,
            skip_mandatory_checks=skip_mandatory_checks,
            skip_internal_checks=can_skip_internal_checks(self, comment_id),
            ignore_current_checks=ignore_current_checks,
        )
        additional_merged_prs = self.merge_changes_locally(
            repo, skip_mandatory_checks, comment_id
        )

        repo.push(self.default_branch(), dry_run)
        if not dry_run:
            self.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run)
            for pr in additional_merged_prs:
                pr.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run)

        # When the merge process reaches this part, we can assume that the commit
        # has been successfully pushed to trunk
        merge_commit_sha = repo.rev_parse(name=self.default_branch())

        if comment_id and self.pr_num:
            # Finally, upload the record to s3. The lists of pending and failed
            # checks are those at the time of the merge
            save_merge_record(
                comment_id=comment_id,
                pr_num=self.pr_num,
                owner=self.org,
                project=self.project,
                author=self.get_author(),
                pending_checks=pending_checks,
                failed_checks=failed_checks,
                ignore_current_checks=ignorable_checks.get("IGNORE_CURRENT_CHECK", []),
                broken_trunk_checks=ignorable_checks.get("BROKEN_TRUNK", []),
                flaky_checks=ignorable_checks.get("FLAKY", []),
                unstable_checks=ignorable_checks.get("UNSTABLE", []),
                last_commit_sha=self.last_commit_sha(default=""),
                merge_base_sha=self.get_merge_base(),
                merge_commit_sha=merge_commit_sha,
                is_failed=False,
                skip_mandatory_checks=skip_mandatory_checks,
                ignore_current=bool(ignore_current_checks),
            )
        else:
            print("Missing comment ID or PR number, couldn't upload to s3")

        # Usually Github will see that the commit has "resolves <pr_num>" in the
        # commit message and close the PR, but sometimes it doesn't, leading to
        # confusion.  When it doesn't, we close it manually.
        time.sleep(60)  # Give Github some time to close the PR
        manually_close_merged_pr(
            pr=self,
            additional_merged_prs=additional_merged_prs,
            merge_commit_sha=merge_commit_sha,
            dry_run=dry_run,
        )
    def merge_changes_locally(
        self,
        repo: GitRepo,
        skip_mandatory_checks: bool = False,
        comment_id: Optional[int] = None,
        branch: Optional[str] = None,
        skip_all_rule_checks: bool = False,
    ) -> list["GitHubPR"]:
        """
        :param skip_all_rule_checks: If true, skips all rule checks on ghstack PRs, useful for dry-running a merge locally
        """
        branch_to_merge_into = self.default_branch() if branch is None else branch
        if repo.current_branch() != branch_to_merge_into:
            repo.checkout(branch_to_merge_into)

        # It's okay to skip the commit SHA check for ghstack PRs since
        # authoring requires write access to the repo.
        if self.is_ghstack_pr():
            return self.merge_ghstack_into(
                repo,
                skip_mandatory_checks,
                comment_id=comment_id,
                skip_all_rule_checks=skip_all_rule_checks,
            )

        msg = self.gen_commit_message()
        pr_branch_name = f"__pull-request-{self.pr_num}__init__"

        # Determine which commit SHA to merge
        commit_to_merge = None
        if not comment_id:
            raise ValueError("Must provide --comment-id when merging regular PRs")

        # Get the commit SHA that was present when the comment was made
        commit_to_merge = self.get_commit_sha_at_comment(comment_id)
        if not commit_to_merge:
            raise RuntimeError(
                f"Could not find commit that was pushed before comment {comment_id}"
            )

        # Validate that this commit is the latest commit on the PR
        latest_commit = self.last_commit_sha()
        if commit_to_merge != latest_commit:
            raise RuntimeError(
                f"Commit {commit_to_merge} was HEAD when comment {comment_id} was posted "
                f"but now the latest commit on the PR is {latest_commit}. "
                f"Please re-issue the merge command to merge the latest commit."
            )

        print(f"Merging commit {commit_to_merge} locally")

        repo.fetch(commit_to_merge, pr_branch_name)
        repo._run_git("merge", "--squash", pr_branch_name)
        repo._run_git("commit", f'--author="{self.get_author()}"', "-m", msg)

        # Did the PR change since we started the merge?
        pulled_sha = repo.show_ref(pr_branch_name)
        latest_pr_status = GitHubPR(self.org, self.project, self.pr_num)
        if (
            pulled_sha != latest_pr_status.last_commit_sha()
            or pulled_sha != commit_to_merge
        ):
            raise RuntimeError(
                "PR has been updated since CI checks last passed. Please rerun the merge command."
            )
        return []
| 
 | |
class MergeRuleFailedError(RuntimeError):
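    """Raised when a PR fails to satisfy the merge rules; carries the relevant rule, if any."""
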
    def __init__(self, message: str, rule: Optional["MergeRule"] = None) -> None:
        super().__init__(message)
        self.rule = rule


class MandatoryChecksMissingError(MergeRuleFailedError):
    pass


class PostCommentError(Exception):
    pass


@dataclass
class MergeRule:
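    """A single merge rule loaded from the repo's merge rules file.

    name: human-readable rule name
    patterns: file patterns the rule applies to
    approved_by: users (or org/team entries) whose approval satisfies the rule
    mandatory_checks_name: checks that must pass before merging
    ignore_flaky_failures: whether flaky/broken-trunk failures may be ignored
    """
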
    name: str
    patterns: list[str]
    approved_by: list[str]
    mandatory_checks_name: Optional[list[str]]
    ignore_flaky_failures: bool = True


def gen_new_issue_link(
    org: str, project: str, labels: list[str], template: str = "bug-report.yml"
) -> str:
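    """Build a URL that opens a new-issue form on the repo with the given labels and template prefilled."""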
    labels_str = ",".join(labels)
    return (
        f"https://github.com/{org}/{project}/issues/new?"
        f"labels={urllib.parse.quote(labels_str)}&"
        f"template={urllib.parse.quote(template)}"
    )


def read_merge_rules(
    repo: Optional[GitRepo], org: str, project: str
) -> list[MergeRule]:
    """Returns the list of all merge rules for the repo or project.

    NB: this function is used in Meta-internal workflows, see the comment
    at the top of this file for details.
    """
    repo_relative_rules_path = MERGE_RULE_PATH
    if repo is None:
        json_data = gh_fetch_url(
            f"https://api.github.com/repos/{org}/{project}/contents/{repo_relative_rules_path}",
            headers={"Accept": "application/vnd.github.v3+json"},
            reader=json.load,
        )
        content = base64.b64decode(json_data["content"])
        return [MergeRule(**x) for x in yaml.safe_load(content)]
    else:
        rules_path = Path(repo.repo_dir) / repo_relative_rules_path
        if not rules_path.exists():
            print(f"{rules_path} does not exist, returning empty rules")
            return []
        with open(rules_path) as fp:
            rc = yaml.safe_load(fp)
        return [MergeRule(**x) for x in rc]


def find_matching_merge_rule(
    pr: GitHubPR,
    repo: Optional[GitRepo] = None,
    skip_mandatory_checks: bool = False,
    skip_internal_checks: bool = False,
    ignore_current_checks: Optional[list[str]] = None,
) -> tuple[
    MergeRule,
    list[tuple[str, Optional[str], Optional[int]]],
    list[tuple[str, Optional[str], Optional[int]]],
    dict[str, list[Any]],
]:
    """
    Returns the merge rule matching this PR together with the lists of associated
    pending and failing jobs, OR raises an exception.

    NB: this function is used in Meta-internal workflows, see the comment at the top of
    this file for details.
    """
    changed_files = pr.get_changed_files()
    approved_by = set(pr.get_approved_by())

    issue_link = gen_new_issue_link(
        org=pr.org,
        project=pr.project,
        labels=["module: ci"],
    )
    reject_reason = f"No rule found to match PR. Please [report]({issue_link}) this issue to DevX team."

    rules = read_merge_rules(repo, pr.org, pr.project)
    if not rules:
        reject_reason = f"Rejecting the merge as no rules are defined for the repository in {MERGE_RULE_PATH}"
        raise RuntimeError(reject_reason)

    checks = pr.get_checkrun_conclusions()
    checks = get_classifications(
        pr.pr_num,
        pr.project,
        checks,
        ignore_current_checks=ignore_current_checks,
    )

    # This keeps the list of all approvers that could stamp the change
    all_rule_approvers = {}

    # A PR can fail multiple merge rules, but it only needs to pass one rule to be approved.
    # If it fails all rules, we need to find the rule that it came closest to passing and report
    # that to the dev.
    #
    # reject_reason_score ranks rules by relevancy. The higher the score, the more relevant the
    # rule & rejection reason, and we only care about the most relevant rule/reason
    #
    # reject_reason_score interpretation:
    # Score 0 to 10K - how many files rule matched
    # Score 10K - matched all files, but no overlapping approvers
    # Score 20K - matched all files and approvers, but mandatory checks are pending
    # Score 30K - matched all files and approvers, but mandatory checks failed
    reject_reason_score = 0
    for rule in rules:
        rule_name = rule.name
        patterns_re = patterns_to_regex(rule.patterns)
        non_matching_files = []

        # Does this rule apply to all the files?
        for fname in changed_files:
            if not patterns_re.match(fname):
                non_matching_files.append(fname)
        if len(non_matching_files) > 0:
            num_matching_files = len(changed_files) - len(non_matching_files)
            if num_matching_files > reject_reason_score:
                reject_reason_score = num_matching_files
                reject_reason = "\n".join(
                    (
                        f"Not all files match rule `{rule_name}`.",
                        f"{num_matching_files} files matched, but there are still non-matching files:",
                        f"{','.join(non_matching_files[:5])}{', ...' if len(non_matching_files) > 5 else ''}",
                    )
                )
            continue

        # If rule needs approvers but PR has not been reviewed, skip it
        if len(rule.approved_by) > 0 and len(approved_by) == 0:
            if reject_reason_score < 10000:
                reject_reason_score = 10000
                reject_reason = f"PR #{pr.pr_num} has not been reviewed yet"
            continue

        # Does the PR have the required approvals for this rule?
        rule_approvers = set()
        for approver in rule.approved_by:
            if "/" in approver:
                org, name = approver.split("/")
                rule_approvers.update(gh_get_team_members(org, name))
            else:
                rule_approvers.add(approver)
        approvers_intersection = approved_by.intersection(rule_approvers)
        # If rule requires approvers but they aren't the ones that reviewed PR
        if len(approvers_intersection) == 0 and len(rule_approvers) > 0:
            # Less than or equal is intentionally used here to gather all potential
            # approvers
            if reject_reason_score <= 10000:
                reject_reason_score = 10000

                all_rule_approvers[rule.name] = rule.approved_by
                # Prepare the reject reason
                all_rule_approvers_msg = [
                    f"- {name} ({', '.join(approved_by[:5])}{', ...' if len(approved_by) > 5 else ''})"
                    for name, approved_by in all_rule_approvers.items()
                ]

                reject_reason = "Approvers from one of the following sets are needed:\n"
                reject_reason += "\n".join(all_rule_approvers_msg)

            continue

        # Does the PR pass the checks required by this rule?
        mandatory_checks = (
            rule.mandatory_checks_name if rule.mandatory_checks_name is not None else []
        )
        required_checks = list(
            filter(
                lambda x: ("EasyCLA" in x)
                or ("Facebook CLA Check" in x)
                or not skip_mandatory_checks,
                mandatory_checks,
            )
        )
        pending_checks, failed_checks, _ = categorize_checks(
            checks,
            required_checks,
            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD
            if rule.ignore_flaky_failures
            else 0,
        )

        # categorize_checks assumes all tests are required if required_checks is empty.
        # this is a workaround as we want to keep that behavior for categorize_checks
        # generally.
        if not required_checks:
            pending_checks = []
            failed_checks = []

        hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit_sha()}"
        if len(failed_checks) > 0:
            if reject_reason_score < 30000:
                reject_reason_score = 30000
                reject_reason = "\n".join(
                    (
                        f"{len(failed_checks)} mandatory check(s) failed.  The first few are:",
                        *checks_to_markdown_bullets(failed_checks),
                        "",
                        f"Dig deeper by [viewing the failures on hud]({hud_link})",
                    )
                )
            continue
        elif len(pending_checks) > 0:
            if reject_reason_score < 20000:
                reject_reason_score = 20000
                reject_reason = "\n".join(
                    (
                        f"{len(pending_checks)} mandatory check(s) are pending/not yet run.  The first few are:",
                        *checks_to_markdown_bullets(pending_checks),
                        "",
                        f"Dig deeper by [viewing the pending checks on hud]({hud_link})",
                    )
                )
            continue

        if not skip_internal_checks and pr.has_internal_changes():
            raise RuntimeError(
                "This PR has internal changes and must be landed via Phabricator! Please try reimporting/reexporting the PR!"
            )

        # Categorize all checks when skip_mandatory_checks (force merge) is set. Do it here
        # where the list of checks is readily available. These records will be saved into
        # s3 merge records
        (
            pending_mandatory_checks,
            failed_mandatory_checks,
            ignorable_checks,
        ) = categorize_checks(
            checks,
            [],
            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD,
        )
        return (
            rule,
            pending_mandatory_checks,
            failed_mandatory_checks,
            ignorable_checks,
        )

    if reject_reason_score == 20000:
        raise MandatoryChecksMissingError(reject_reason, rule)
    raise MergeRuleFailedError(reject_reason, rule)


def checks_to_str(checks: list[tuple[str, Optional[str]]]) -> str:
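    """Render checks as a comma-separated string of markdown links (name only if no URL)."""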
    return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)


def checks_to_markdown_bullets(
    checks: list[tuple[str, Optional[str], Optional[int]]],
) -> list[str]:
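    """Render at most the first five checks as markdown bullet links."""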
    return [
        f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
    ]


def post_starting_merge_comment(
    repo: GitRepo,
    pr: GitHubPR,
    explainer: TryMergeExplainer,
    dry_run: bool,
    ignore_current_checks_info: Optional[
        list[tuple[str, Optional[str], Optional[int]]]
    ] = None,
) -> None:
    """Post the initial merge starting message on the PR. Also post a short
    message on all PRs in the stack."""
    gh_post_pr_comment(
        pr.org,
        pr.project,
        pr.pr_num,
        explainer.get_merge_message(ignore_current_checks_info),
        dry_run=dry_run,
    )
    if pr.is_ghstack_pr():
        for additional_prs, _ in get_ghstack_prs(repo, pr):
            if additional_prs.pr_num != pr.pr_num:
                gh_post_pr_comment(
                    additional_prs.org,
                    additional_prs.project,
                    additional_prs.pr_num,
                    f"Starting merge as part of PR stack under #{pr.pr_num}",
                    dry_run=dry_run,
                )


def manually_close_merged_pr(
    pr: GitHubPR,
    additional_merged_prs: list[GitHubPR],
    merge_commit_sha: str,
    dry_run: bool,
) -> None:
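    """Close the merged PR (and any PRs merged with it as part of a stack) if
    GitHub did not close them automatically, posting an explanatory comment first."""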
    def _comment_and_close(pr: GitHubPR, comment: str) -> None:
        pr = GitHubPR(pr.org, pr.project, pr.pr_num)  # Refresh the PR
        if not pr.is_closed():
            gh_post_pr_comment(pr.org, pr.project, pr.pr_num, comment, dry_run)
            gh_close_pr(pr.org, pr.project, pr.pr_num, dry_run)

    message = (
        f"This PR (#{pr.pr_num}) was merged in {merge_commit_sha} but it is still open, likely due to a Github bug, "
        "so mergebot is closing it manually.  If you think this is a mistake, please feel free to reopen and contact Dev Infra."
    )
    _comment_and_close(pr, message)
    for additional_pr in additional_merged_prs:
        message = (
            f"This PR (#{additional_pr.pr_num}) was merged as part of PR #{pr.pr_num} in the stack under {merge_commit_sha} "
            "but it is still open, likely due to a Github bug, so mergebot is closing it manually. "
            "If you think this is a mistake, please feel free to reopen and contact Dev Infra."
        )
        _comment_and_close(additional_pr, message)

    print(f"PR {pr.pr_num} and all additional PRs in the stack have been closed.")


@retries_decorator()
def save_merge_record(
    comment_id: int,
    pr_num: int,
    owner: str,
    project: str,
    author: str,
    pending_checks: list[tuple[str, Optional[str], Optional[int]]],
    failed_checks: list[tuple[str, Optional[str], Optional[int]]],
    ignore_current_checks: list[tuple[str, Optional[str], Optional[int]]],
    broken_trunk_checks: list[tuple[str, Optional[str], Optional[int]]],
    flaky_checks: list[tuple[str, Optional[str], Optional[int]]],
    unstable_checks: list[tuple[str, Optional[str], Optional[int]]],
    last_commit_sha: str,
    merge_base_sha: str,
    merge_commit_sha: str = "",
    is_failed: bool = False,
    skip_mandatory_checks: bool = False,
    ignore_current: bool = False,
    error: str = "",
) -> None:
    """
    This saves the merge record as a JSON file, which can later be uploaded to s3
    """

    # Prepare the record to be written into s3
    data = [
        {
            "comment_id": comment_id,
            "pr_num": pr_num,
            "owner": owner,
            "project": project,
            "author": author,
            "pending_checks": pending_checks,
            "failed_checks": failed_checks,
            "ignore_current_checks": ignore_current_checks,
            "broken_trunk_checks": broken_trunk_checks,
            "flaky_checks": flaky_checks,
            "unstable_checks": unstable_checks,
            "last_commit_sha": last_commit_sha,
            "merge_base_sha": merge_base_sha,
            "merge_commit_sha": merge_commit_sha,
            "is_failed": is_failed,
            "skip_mandatory_checks": skip_mandatory_checks,
            "ignore_current": ignore_current,
            "error": error,
            # This is a unique identifier for the record for deduping purposes
            # in Rockset.  Any unique string would work.  This will not be used
            # after we migrate off Rockset
            "_id": f"{project}-{pr_num}-{comment_id}-{os.environ.get('GITHUB_RUN_ID')}",
        }
    ]
    repo_root = Path(__file__).resolve().parent.parent.parent

    with open(repo_root / "merge_record.json", "w") as f:
        json.dump(data, f)


@retries_decorator()
def get_drci_classifications(pr_num: int, project: str = "pytorch") -> Any:
    """
    Query HUD API to find similar failures to decide if they are flaky
    """
    # NB: This doesn't work internally atm because this requires making an
    # external API call to HUD
    failures = gh_fetch_url(
        f"https://hud.pytorch.org/api/drci/drci?prNumber={pr_num}",
        data=f"repo={project}",
        headers={
            "Authorization": os.getenv("DRCI_BOT_KEY", ""),
            "Accept": "application/vnd.github.v3+json",
        },
        method="POST",
        reader=json.load,
    )

    return failures.get(str(pr_num), {}) if failures else {}


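# Matches the trailing shard/runner suffix of a job name, e.g. the
# ", 1, 2, <runner>)" tail of a name like "test (default, 1, 2, <runner>)"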
REMOVE_JOB_NAME_SUFFIX_REGEX = re.compile(r", [0-9]+, [0-9]+, .+\)$")


def remove_job_name_suffix(name: str, replacement: str = ")") -> str:
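    """Replace the trailing shard/runner suffix of a job name with `replacement`."""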
    return re.sub(REMOVE_JOB_NAME_SUFFIX_REGEX, replacement, name)


def is_broken_trunk(
    check: JobCheckState,
    drci_classifications: Any,
) -> bool:
    if not check or not drci_classifications:
        return False

    name = check.name
    job_id = check.job_id

    # Consult the list of broken trunk failures from Dr.CI
    return any(
        (name == broken_trunk["name"]) or (job_id and job_id == broken_trunk["id"])
        for broken_trunk in drci_classifications.get("BROKEN_TRUNK", [])
    )


def is_unstable(
    check: JobCheckState,
    drci_classifications: Any,
) -> bool:
    if not check or not drci_classifications:
        return False

    name = check.name
    job_id = check.job_id

    # The job name has the unstable keyword. This is the original way to mark a job
    # as unstable on HUD, Dr.CI, and trymerge
    if "unstable" in name:
        return True

    # Consult the list of unstable failures from Dr.CI
    return any(
        (name == unstable["name"] or (job_id and job_id == unstable["id"]))
        for unstable in drci_classifications.get("UNSTABLE", [])
    )


def is_flaky(
    check: JobCheckState,
    drci_classifications: Any,
) -> bool:
    if not check or not drci_classifications:
        return False

    name = check.name
    job_id = check.job_id

    # Consult the list of flaky failures from Dr.CI
    return any(
        (name == flaky["name"] or (job_id and job_id == flaky["id"]))
        for flaky in drci_classifications.get("FLAKY", [])
    )


def is_invalid_cancel(
    name: str,
    conclusion: Optional[str],
    drci_classifications: Any,
) -> bool:
    """
    After https://github.com/pytorch/test-infra/pull/4579, invalid cancelled
    signals have been removed from HUD and Dr.CI. The same needs to be done
    here for consistency
    """
    if (
        not name
        or not drci_classifications
        or not conclusion
        or conclusion.upper() != "CANCELLED"
    ):
        return False

    # If a job is cancelled and not listed as a failure by Dr.CI, it's an
    # invalid signal and can be ignored
    return all(
        name != failure["name"] for failure in drci_classifications.get("FAILED", [])
    )


def get_classifications(
    pr_num: int,
    project: str,
    checks: dict[str, JobCheckState],
    ignore_current_checks: Optional[list[str]],
) -> dict[str, JobCheckState]:
    # Get the failure classification from Dr.CI, which is the source of truth
    # going forward. It's preferable to try calling Dr.CI API directly first
    # to get the latest results as well as update Dr.CI PR comment
    drci_classifications = get_drci_classifications(pr_num=pr_num, project=project)

    def get_readable_drci_results(drci_classifications: Any) -> str:
        try:
            s = f"From Dr.CI API ({pr_num}):\n"
            for classification, jobs in drci_classifications.items():
                s += f"  {classification}: \n"
                for job in jobs:
                    s += f"    {job['id']} {job['name']}\n"
            return s
        except Exception:
            return f"From Dr.CI API: {json.dumps(drci_classifications)}"

    print(get_readable_drci_results(drci_classifications))

    # NB: if the latest results from Dr.CI are not available, i.e. when calling from
    # SandCastle, we fall back to any results we can find on the Dr.CI check run summary
    if (
        not drci_classifications
        and DRCI_CHECKRUN_NAME in checks
        and checks[DRCI_CHECKRUN_NAME]
        and checks[DRCI_CHECKRUN_NAME].summary
    ):
        drci_summary = checks[DRCI_CHECKRUN_NAME].summary
        try:
            print(f"From Dr.CI checkrun summary: {drci_summary}")
            drci_classifications = json.loads(str(drci_summary))
        except json.JSONDecodeError:
            warn("Invalid Dr.CI checkrun summary")
            drci_classifications = {}

    checks_with_classifications = checks.copy()
    for name, check in checks.items():
        if check.status == "SUCCESS" or check.status == "NEUTRAL":
            continue

        if is_unstable(check, drci_classifications):
            checks_with_classifications[name] = JobCheckState(
                check.name,
                check.url,
                check.status,
                "UNSTABLE",
                check.job_id,
                check.title,
                check.summary,
            )
            continue

        # NB: It's important to note that when it comes to ghstack and broken trunk classification,
        # Dr.CI uses the base of the whole stack
        if is_broken_trunk(check, drci_classifications):
            checks_with_classifications[name] = JobCheckState(
                check.name,
                check.url,
                check.status,
                "BROKEN_TRUNK",
                check.job_id,
                check.title,
                check.summary,
            )
            continue

        elif is_flaky(check, drci_classifications):
            checks_with_classifications[name] = JobCheckState(
                check.name,
                check.url,
                check.status,
                "FLAKY",
                check.job_id,
                check.title,
                check.summary,
            )
            continue

        elif is_invalid_cancel(name, check.status, drci_classifications):
            # NB: Create a new category here for invalid cancelled signals because
            # there are usually many of them when they happen. So, they shouldn't
            # be counted toward ignorable failures threshold
            checks_with_classifications[name] = JobCheckState(
                check.name,
                check.url,
                check.status,
                "INVALID_CANCEL",
                check.job_id,
                check.title,
                check.summary,
            )
            continue

        if ignore_current_checks is not None and name in ignore_current_checks:
            checks_with_classifications[name] = JobCheckState(
                check.name,
                check.url,
                check.status,
                "IGNORE_CURRENT_CHECK",
                check.job_id,
                check.title,
                check.summary,
            )

    return checks_with_classifications


def filter_checks_with_lambda(
    checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
) -> list[JobCheckState]:
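    """Return the checks whose status satisfies the given predicate."""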
    return [check for check in checks.values() if status_filter(check.status)]


def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
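    """Return the SHA of the commit that merged the PR, preferring the merge commit
    recorded on GitHub and falling back to scanning local history."""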
    commit_sha = pr.get_merge_commit()
    if commit_sha is not None:
        return commit_sha
    commits = repo.commits_resolving_gh_pr(pr.pr_num)
    if len(commits) == 0:
        raise PostCommentError("Can't find any commits resolving PR")
    return commits[0]


def validate_revert(
    repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
) -> tuple[str, str]:
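    """Check that the revert comment is unedited, that its author is allowed to
    revert, and that the PR matches a merge rule; returns (author_login, commit_sha)."""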
    comment = (
        pr.get_last_comment()
        if comment_id is None
        else pr.get_comment_by_id(comment_id)
    )
    if comment.editor_login is not None:
        raise PostCommentError(
            "Halting the revert as the revert comment has been edited."
        )
    author_association = comment.author_association
    author_login = comment.author_login
    allowed_reverters = ["COLLABORATOR", "MEMBER", "OWNER"]
    # For some reason, one cannot be a MEMBER of a private repo, only a CONTRIBUTOR
    if pr.is_base_repo_private():
        allowed_reverters.append("CONTRIBUTOR")
    if author_association not in allowed_reverters:
        raise PostCommentError(
            f"Will not revert as @{author_login} is not one of "
            f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
        )

    # Raises exception if matching rule is not found, but ignores all status checks
    find_matching_merge_rule(
        pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
    )
    commit_sha = get_pr_commit_sha(repo, pr)
    return (author_login, commit_sha)


def get_ghstack_dependent_prs(
    repo: GitRepo, pr: GitHubPR, only_closed: bool = True
) -> list[tuple[str, GitHubPR]]:
    """
    Get the PRs in the stack that are above this PR (inclusive).
    Throws an error if the stack has branched or the original branches are gone
    """
    assert pr.is_ghstack_pr()
    orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
    rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
    if len(rev_list) == 0:
        raise RuntimeError(
            f"PR {pr.pr_num} does not have any revisions associated with it"
        )
    skip_len = len(rev_list) - 1
    for branch in repo.branches_containing_ref(orig_ref):
        candidate = repo.revlist(f"{pr.default_branch()}..{branch}")
        # Pick longest candidate
        if len(candidate) > len(rev_list):
            candidate, rev_list = rev_list, candidate
        # Validate that candidate always ends rev-list
        if rev_list[-len(candidate) :] != candidate:
            raise RuntimeError(
                f"Branch {branch} revlist {', '.join(candidate)} is not a subset of {', '.join(rev_list)}"
            )
    # Remove commits original PR depends on
    if skip_len > 0:
        rev_list = rev_list[:-skip_len]
    rc: list[tuple[str, GitHubPR]] = []
    for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
        if not pr_.is_closed():
            if not only_closed:
                rc.append(("", pr_))
            continue
        commit_sha = get_pr_commit_sha(repo, pr_)
        rc.append((commit_sha, pr_))
    return rc


def do_revert_prs(
    repo: GitRepo,
    original_pr: GitHubPR,
    shas_and_prs: list[tuple[str, GitHubPR]],
    *,
    author_login: str,
    extra_msg: str = "",
    skip_internal_checks: bool = False,
    dry_run: bool = False,
) -> None:
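    """Revert the given (commit_sha, PR) pairs: push the revert commits in one go,
    then comment on, label, and reopen each reverted PR."""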
    # Prepare and push revert commits
    for commit_sha, pr in shas_and_prs:
        revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
        revert_msg += extra_msg
        repo.checkout(pr.default_branch())
        repo.revert(commit_sha)
        msg = repo.commit_message("HEAD")
        msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg)
        msg += revert_msg
        repo.amend_commit_message(msg)
    repo.push(shas_and_prs[0][1].default_branch(), dry_run)

    # Comment/reopen PRs
    for commit_sha, pr in shas_and_prs:
        revert_message = ""
        if pr.pr_num == original_pr.pr_num:
            revert_message += (
                f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
            )
        else:
            revert_message += (
                f"@{pr.get_pr_creator_login()} your PR has been reverted as part of the stack under "
                f"#{original_pr.pr_num}.\n"
            )
        if (
            pr.has_internal_changes()
            and not pr.has_no_connected_diff()
            and not skip_internal_checks
        ):
            revert_message += "\n:warning: This PR might contain internal changes"
            revert_message += "\ncc: @pytorch/pytorch-dev-infra"
        gh_post_pr_comment(
            pr.org, pr.project, pr.pr_num, revert_message, dry_run=dry_run
        )

        pr.add_numbered_label("reverted", dry_run)
        pr.add_label("ci-no-td", dry_run)
        if not dry_run:
            gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg)
            gh_update_pr_state(pr.org, pr.project, pr.pr_num)


def try_revert(
    repo: GitRepo,
    pr: GitHubPR,
    *,
    dry_run: bool = False,
    comment_id: Optional[int] = None,
    reason: Optional[str] = None,
) -> None:
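    """Entry point for revert requests: validate the request, collect the full
    ghstack stack if applicable, and revert the PR(s)."""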
    try:
        author_login, commit_sha = validate_revert(repo, pr, comment_id=comment_id)
    except PostCommentError as e:
        gh_post_pr_comment(pr.org, pr.project, pr.pr_num, str(e), dry_run=dry_run)
        return

    extra_msg = f" due to {reason}" if reason is not None else ""
    extra_msg += (
        f" ([comment]({pr.get_comment_by_id(comment_id).url}))\n"
        if comment_id is not None
        else "\n"
    )
    shas_and_prs = [(commit_sha, pr)]
    if pr.is_ghstack_pr():
        try:
            shas_and_prs = get_ghstack_dependent_prs(repo, pr)
            prs_to_revert = " ".join([t[1].get_pr_url() for t in shas_and_prs])
            print(f"About to revert stack of PRs: {prs_to_revert}")
        except Exception as e:
            print(
                f"Failed to fetch dependent PRs: {str(e)}, falling back to single revert"
            )

    do_revert_prs(
        repo,
        pr,
        shas_and_prs,
        author_login=author_login,
        extra_msg=extra_msg,
        dry_run=dry_run,
        skip_internal_checks=can_skip_internal_checks(pr, comment_id),
    )


def prefix_with_github_url(suffix_str: str) -> str:
    return f"https://github.com/{suffix_str}"


def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
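    """Raise if the repo has an open, merge-blocking "ci: sev" issue
    (unless mandatory checks are being skipped)."""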
    if skip_mandatory_checks:
        return
    response = cast(
        dict[str, Any],
        gh_fetch_json_list(
            "https://api.github.com/search/issues",  # @lint-ignore
            # Having two label: queries is an AND operation
            params={
                "q": f'repo:{org}/{project} is:open is:issue label:"ci: sev" label:"merge blocking"'
            },
        ),
    )
    if response["total_count"] != 0:
        raise RuntimeError(
            "Not merging any PRs at the moment because there is a "
            + "merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at: \n"
            + f"{response['items'][0]['html_url']}"
        )
    return


def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
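    """Return True if any of the labels matches the pattern (ciflow labels by default)."""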
    return len(list(filter(pattern.match, labels))) > 0


def categorize_checks(
    check_runs: JobNameToStateDict,
    required_checks: list[str],
    ok_failed_checks_threshold: Optional[int] = None,
) -> tuple[
    list[tuple[str, Optional[str], Optional[int]]],
    list[tuple[str, Optional[str], Optional[int]]],
    dict[str, list[Any]],
]:
    """
    Categorizes all jobs into lists of pending and failing jobs. All known flaky
    failures and broken trunk failures are ignored by default when
    ok_failed_checks_threshold is not set (unlimited)
    """
    pending_checks: list[tuple[str, Optional[str], Optional[int]]] = []
    failed_checks: list[tuple[str, Optional[str], Optional[int]]] = []

    # failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3
    failed_checks_categorization: dict[str, list[Any]] = defaultdict(list)

    # If required_checks is not set or empty, consider all names are relevant
    relevant_checknames = [
        name
        for name in check_runs.keys()
        if not required_checks or any(x in name for x in required_checks)
    ]

    for checkname in required_checks:
        if all(checkname not in x for x in check_runs.keys()):
            pending_checks.append((checkname, None, None))

    for checkname in relevant_checknames:
        status = check_runs[checkname].status
        url = check_runs[checkname].url
        classification = check_runs[checkname].classification
        job_id = check_runs[checkname].job_id

        if status is None and classification != "UNSTABLE":
            # NB: No need to wait if the job classification is unstable as it would be
            # ignored anyway. This is useful to not need to wait for scarce resources
            # like ROCm, which is also frequently in unstable mode
            pending_checks.append((checkname, url, job_id))
        elif classification == "INVALID_CANCEL":
            continue
        elif not is_passing_status(check_runs[checkname].status):
            target = (
                failed_checks_categorization[classification]
                if classification
                in ("IGNORE_CURRENT_CHECK", "BROKEN_TRUNK", "FLAKY", "UNSTABLE")
                else failed_checks
            )
            target.append((checkname, url, job_id))

    flaky_or_broken_trunk = (
        failed_checks_categorization["BROKEN_TRUNK"]
        + failed_checks_categorization["FLAKY"]
    )

    if flaky_or_broken_trunk:
        warn(
            f"The following {len(flaky_or_broken_trunk)} checks failed but were likely due to flakiness or broken trunk: "
            + ", ".join([x[0] for x in flaky_or_broken_trunk])
            + (
                f" but this is greater than the threshold of {ok_failed_checks_threshold} so merge will fail"
                if ok_failed_checks_threshold is not None
                and len(flaky_or_broken_trunk) > ok_failed_checks_threshold
                else ""
            )
        )

    if (
        ok_failed_checks_threshold is not None
        and len(flaky_or_broken_trunk) > ok_failed_checks_threshold
    ):
        failed_checks = failed_checks + flaky_or_broken_trunk

    # The list of failed_checks_categorization is returned so that it can be saved into the s3 merge record
    return (pending_checks, failed_checks, failed_checks_categorization)


def merge(
    pr: GitHubPR,
    repo: GitRepo,
    comment_id: int,
    dry_run: bool = False,
    skip_mandatory_checks: bool = False,
    timeout_minutes: int = 400,
    stale_pr_days: int = 3,
    ignore_current: bool = False,
) -> None:
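    """Top-level merge loop: validate the PR against the merge rules, then poll its
    checks every 5 minutes until they all pass, the merge lands, or timeout_minutes
    elapses."""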
    initial_commit_sha = pr.last_commit_sha()
    pr_link = f"https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num}"
    print(f"Attempting merge of {initial_commit_sha} ({pr_link})")

    if MERGE_IN_PROGRESS_LABEL not in pr.get_labels():
        gh_add_labels(pr.org, pr.project, pr.pr_num, [MERGE_IN_PROGRESS_LABEL], dry_run)

    explainer = TryMergeExplainer(
        skip_mandatory_checks,
        pr.get_labels(),
        pr.pr_num,
        pr.org,
        pr.project,
        ignore_current,
    )

    # probably a bad name, but this is a list of current checks that should be
    # ignored and is toggled by the --ignore-current flag
    ignore_current_checks_info = []

    if pr.is_ghstack_pr():
        get_ghstack_prs(repo, pr)  # raises error if out of sync

    check_for_sev(pr.org, pr.project, skip_mandatory_checks)

    if skip_mandatory_checks:
        post_starting_merge_comment(repo, pr, explainer, dry_run)
        return pr.merge_into(
            repo,
            dry_run=dry_run,
            skip_mandatory_checks=skip_mandatory_checks,
            comment_id=comment_id,
        )

    # Check for approvals
    find_matching_merge_rule(pr, repo, skip_mandatory_checks=True)

    if not has_required_labels(pr):
        raise RuntimeError(LABEL_ERR_MSG.lstrip(" #"))

    if ignore_current:
        checks = pr.get_checkrun_conclusions()
        _, failing, _ = categorize_checks(
            checks,
            list(checks.keys()),
            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD,
        )
        ignore_current_checks_info = failing

    post_starting_merge_comment(
        repo,
        pr,
        explainer,
        dry_run,
        ignore_current_checks_info=ignore_current_checks_info,
    )

    start_time = time.time()
    last_exception = ""
    elapsed_time = 0.0
    ignore_current_checks = [
        x[0] for x in ignore_current_checks_info
    ]  # convert to List[str] for convenience
    while elapsed_time < timeout_minutes * 60:
        check_for_sev(pr.org, pr.project, skip_mandatory_checks)
        current_time = time.time()
        elapsed_time = current_time - start_time
        print(
            f"Attempting merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} ({elapsed_time / 60} minutes elapsed)"
        )
        pr = GitHubPR(pr.org, pr.project, pr.pr_num)
        if initial_commit_sha != pr.last_commit_sha():
            raise RuntimeError(
                "New commits were pushed while merging. Please rerun the merge command."
            )
        try:
            required_checks = []
            failed_rule_message = None
            ignore_flaky_failures = True
            try:
                find_matching_merge_rule(
                    pr, repo, ignore_current_checks=ignore_current_checks
                )
            except MandatoryChecksMissingError as ex:
                if ex.rule is not None:
                    ignore_flaky_failures = ex.rule.ignore_flaky_failures
                    if ex.rule.mandatory_checks_name is not None:
                        required_checks = ex.rule.mandatory_checks_name
                failed_rule_message = ex

            checks = pr.get_checkrun_conclusions()
            checks = get_classifications(
                pr.pr_num,
                pr.project,
                checks,
                ignore_current_checks=ignore_current_checks,
            )
            pending, failing, _ = categorize_checks(
                checks,
                required_checks
                + [x for x in checks.keys() if x not in required_checks],
                ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD
                if ignore_flaky_failures
                else 0,
            )
            # HACK until GitHub will be better about surfacing those
            startup_failures = filter_checks_with_lambda(
                checks, lambda status: status == "STARTUP_FAILURE"
            )
            if len(startup_failures) > 0:
                raise RuntimeError(
                    f"{len(startup_failures)} STARTUP failures reported, please check workflows syntax! "
                    + ", ".join(f"[{x.name}]({x.url})" for x in startup_failures[:5])
                )
            # END of HACK

            if len(failing) > 0:
                raise RuntimeError(
                    f"{len(failing)} jobs have failed, first few of them are: "
                    + ", ".join(f"[{x[0]}]({x[1]})" for x in failing[:5])
                )
            if len(pending) > 0:
                if failed_rule_message is not None:
                    raise failed_rule_message
                else:
                    raise MandatoryChecksMissingError(
                        f"Still waiting for {len(pending)} jobs to finish, "
                        + f"first few of them are: {', '.join(x[0] for x in pending[:5])}"
                    )

            return pr.merge_into(
                repo,
                dry_run=dry_run,
                skip_mandatory_checks=skip_mandatory_checks,
                comment_id=comment_id,
                ignore_current_checks=ignore_current_checks,
            )
        except MandatoryChecksMissingError as ex:
            last_exception = str(ex)
            print(
                f"Merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} failed due to: {ex}. Retrying in 5 min",
                flush=True,
            )
            time.sleep(5 * 60)
    # Finally report timeout back
    msg = f"Merge timed out after {timeout_minutes} minutes. Please contact the pytorch_dev_infra team."
    msg += f"\nThe last exception was: {last_exception}"
    gh_add_labels(pr.org, pr.project, pr.pr_num, ["land-failed"], dry_run)
    raise RuntimeError(msg)


def main() -> None:
    args = parse_args()
    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, args.pr_num)

    def handle_exception(e: Exception, title: str = "Merge failed") -> None:
        exception = f"**Reason**: {e}"

        failing_rule = None
        if isinstance(e, MergeRuleFailedError):
            failing_rule = e.rule.name if e.rule else None

        internal_debugging = ""
        run_url = os.getenv("GH_RUN_URL")
        if run_url is not None:
            # Hide this behind a collapsed bullet since it's not helpful to most devs
            internal_debugging = "\n".join(
                line
                for line in (
                    "<details><summary>Details for Dev Infra team</summary>",
                    f'Raised by <a href="{run_url}">workflow job</a>\n',
                    f"Failing merge rule: {failing_rule}" if failing_rule else "",
                    "</details>",
                )
                if line
            )  # ignore empty lines during the join

        msg = "\n".join((f"## {title}", f"{exception}", "", f"{internal_debugging}"))

        gh_post_pr_comment(org, project, args.pr_num, msg, dry_run=args.dry_run)
        import traceback

        traceback.print_exc()

    if args.revert:
        try:
            gh_post_pr_comment(
                org,
                project,
                args.pr_num,
                get_revert_message(org, project, pr.pr_num),
                args.dry_run,
            )
            try_revert(
                repo,
                pr,
                dry_run=args.dry_run,
                comment_id=args.comment_id,
                reason=args.reason,
            )
        except Exception as e:
            handle_exception(e, f"Reverting PR {args.pr_num} failed")
        return

    if pr.is_closed():
        gh_post_pr_comment(
            org,
            project,
            args.pr_num,
            f"Can't merge closed PR #{args.pr_num}",
            dry_run=args.dry_run,
        )
        return

    if pr.is_cross_repo() and pr.is_ghstack_pr():
        gh_post_pr_comment(
            org,
            project,
            args.pr_num,
            "Cross-repo ghstack merges are not supported",
            dry_run=args.dry_run,
        )
        return
    if not pr.is_ghstack_pr() and pr.base_ref() != pr.default_branch():
        gh_post_pr_comment(
            org,
            project,
            args.pr_num,
            f"PR targets {pr.base_ref()} rather than {pr.default_branch()}, refusing merge request",
            dry_run=args.dry_run,
        )
        return

    if args.check_mergeability:
        if pr.is_ghstack_pr():
            get_ghstack_prs(repo, pr)  # raises error if out of sync
        pr.merge_changes_locally(
            repo,
            skip_mandatory_checks=True,
            skip_all_rule_checks=True,
        )
        return

    if not args.force and pr.has_invalid_submodule_updates():
        message = (
            f"This PR updates submodules {', '.join(pr.get_changed_submodules())}\n"
        )
        message += '\nIf those updates are intentional, please add "submodule" keyword to PR title/description.'
        gh_post_pr_comment(org, project, args.pr_num, message, dry_run=args.dry_run)
        return
    try:
        # Ensure comment id is set, else fail
        if not args.comment_id:
            raise ValueError(
                "Comment ID is required for merging PRs, please provide it using --comment-id"
            )

        merge(
            pr,
            repo,
            comment_id=args.comment_id,
            dry_run=args.dry_run,
            skip_mandatory_checks=args.force,
            ignore_current=args.ignore_current,
        )
    except Exception as e:
        handle_exception(e)

        if args.comment_id and args.pr_num:
            # Finally, upload the record to s3, we don't have access to the
            # list of pending and failed checks here, but they are not really
            # needed at the moment
            save_merge_record(
                comment_id=args.comment_id,
                pr_num=args.pr_num,
                owner=org,
                project=project,
                author=pr.get_author(),
                pending_checks=[],
                failed_checks=[],
                ignore_current_checks=[],
                broken_trunk_checks=[],
                flaky_checks=[],
                unstable_checks=[],
                last_commit_sha=pr.last_commit_sha(default=""),
                merge_base_sha=pr.get_merge_base(),
                is_failed=True,
                skip_mandatory_checks=args.force,
                ignore_current=args.ignore_current,
                error=str(e),
            )
        else:
            print("Missing comment ID or PR number, couldn't upload to s3")
    finally:
        if not args.check_mergeability:
            gh_remove_label(
                org, project, args.pr_num, MERGE_IN_PROGRESS_LABEL, args.dry_run
            )


if __name__ == "__main__":
    main()