[BE] Clean up trymerge code handling broken trunk failures (#111520)

This is the final part of https://github.com/pytorch/pytorch/pull/110054.  Broken trunk classification is now done on Dr.CI, so trymerge only needs to consult Dr.CI's result, which keeps the two consistent when ghstack is used (see the sketch after the list).

* [x] https://github.com/pytorch/pytorch/pull/110054
* [x] https://github.com/pytorch/pytorch/pull/110133
* [x] This PR to clean up the broken trunk logic.
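
The mechanism is a straight name lookup: a failure counts as broken trunk when Dr.CI lists the job under `BROKEN_TRUNK` in its classification payload. A minimal sketch, mirroring the rewritten `is_broken_trunk` in the diff below (`drci_classifications` is the JSON payload returned by `get_drci_classifications`):

```python
from typing import Any


def is_broken_trunk(name: str, drci_classifications: Any) -> bool:
    # Without a job name or a Dr.CI payload, there is nothing to match against
    if not name or not drci_classifications:
        return False
    # A failure is broken trunk iff Dr.CI lists the job name under BROKEN_TRUNK
    return any(
        name == broken_trunk["name"]
        for broken_trunk in drci_classifications.get("BROKEN_TRUNK", [])
    )
```

`is_flaky` and `is_invalid_cancel` follow the same pattern against Dr.CI's `FLAKY` and `FAILED` lists.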

One important change is that `get_classifications` no longer needs to query Rockset for the jobs on the head and merge-base SHAs, saving a query there.  The function is now considerably simpler.
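
Roughly, every call site now reduces to the following sketch (arguments per the diff below; the final `[]` is the `ignore_current_checks` list):

```python
pr = GitHubPR("pytorch", "pytorch", pr_num)
checks = pr.get_checkrun_conclusions()
# No more head_sha/merge_base arguments and no Rockset round trip:
# Dr.CI supplies the classifications directly
checks = get_classifications(pr.pr_num, pr.project, checks, [])
```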

### Testing

https://github.com/pytorch/pytorch/pull/111253 had one broken trunk failure, which Dr.CI detected from the base commit 3eb5cae3af (valid), while trymerge missed it because the ghstack base commit be8e517174 didn't have the same failure (miss).
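
This scenario is covered by the new `111253` test case added in the diff below; condensed (assuming the trymerge test environment with its mocked check runs and Dr.CI responses):

```python
pr = GitHubPR("pytorch", "pytorch", 111253)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(pr.pr_num, pr.project, checks, [])

# The default threshold ignores the Dr.CI-classified broken trunk and flaky
# failures, leaving only the one genuinely related failure
pending, failed, _ = categorize_checks(checks, list(checks.keys()))
assert len(pending) == 0 and len(failed) == 1

# With ok_failed_checks_threshold=0, nothing is ignorable: the 2 unrelated
# failures plus the related one remain
_, failed, _ = categorize_checks(
    checks, list(checks.keys()), ok_failed_checks_threshold=0
)
assert len(failed) == 3
```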

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111520
Approved by: https://github.com/clee2000
Huy Do authored 2023-10-19 02:30:53 +00:00; committed by PyTorch MergeBot
parent 4f0cf1e1ff
commit 4ec777e9a5
5 changed files with 68 additions and 262 deletions

3 binary files changed (not shown)

.github/scripts/test_trymerge.py

@@ -12,7 +12,7 @@ import json
import os
import warnings
from hashlib import sha256
-from typing import Any, cast, Dict, List, Optional
+from typing import Any, Dict, List, Optional
from unittest import main, mock, skip, TestCase
from urllib.error import HTTPError
@@ -27,7 +27,6 @@ from trymerge import (
gh_get_team_members,
gh_graphql,
GitHubPR,
-is_broken_trunk,
main as trymerge_main,
MandatoryChecksMissingError,
MergeRule,
@@ -544,102 +543,6 @@ class TestTryMerge(TestCase):
for case in test_cases:
self.assertEqual(case["expected"], remove_job_name_suffix(case["name"]))
-def test_is_broken_trunk(self, *args: Any) -> None:
-test_cases: List[Dict[str, Any]] = [
-{
-"head_job": None,
-"base_jobs": {
-"job_a": {
-"conclusion": "success",
-"failure_captures": ["a", "b"],
-},
-"job_b": {
-"conclusion": "failure",
-"failure_captures": ["a", "b"],
-},
-},
-"expected": False,
-"description": "Invalid input - head job",
-},
-{
-"head_job": {
-"conclusion": "failure",
-"failure_captures": ["a", "b"],
-},
-"base_jobs": None,
-"expected": False,
-"description": "Invalid input - base jobs",
-},
-{
-"head_job": {
-"conclusion": "failure",
-"failure_captures": ["a", "b"],
-},
-"base_jobs": {},
-"expected": False,
-"description": "Invalid input - empty base jobs",
-},
-{
-"head_job": {
-"conclusion": "failure",
-"failure_captures": ["x", "y"],
-},
-"base_jobs": {
-"job_a": {
-"conclusion": "success",
-"failure_captures": ["a", "b"],
-},
-"job_b": {
-"conclusion": "failure",
-"failure_captures": ["x", "y"],
-},
-},
-"expected": True,
-"description": "Found a match",
-},
-{
-"head_job": {
-"conclusion": "success",
-"failure_captures": ["x", "y"],
-},
-"base_jobs": {
-"job_a": {
-"conclusion": "success",
-"failure_captures": ["a", "b"],
-},
-"job_b": {
-"conclusion": "failure",
-"failure_captures": ["x", "y"],
-},
-},
-"expected": False,
-"description": "Not found - different conclusion",
-},
-{
-"head_job": {
-"conclusion": "failure",
-"failure_captures": ["a", "b"],
-},
-"base_jobs": {
-"job_a": {
-"conclusion": "success",
-"failure_captures": ["a", "b"],
-},
-"job_b": {
-"conclusion": "failure",
-"failure_captures": ["x", "y"],
-},
-},
-"expected": False,
-"description": "Not found - different captured failures",
-},
-]
-for case in test_cases:
-self.assertEqual(
-case["expected"], is_broken_trunk(case["head_job"], case["base_jobs"])
-)
def test_get_merge_base(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 104121)
@@ -669,8 +572,6 @@ class TestBypassFailures(TestCase):
pr.pr_num,
pr.project,
checks,
-pr.last_commit()["oid"],
-pr.get_merge_base(),
[],
)
self.assertTrue(
@@ -683,13 +584,13 @@ class TestBypassFailures(TestCase):
checks[
"trunk / win-vs2019-cpu-py3 / test (default, 2, 3, windows.4xlarge.nonephemeral)"
].classification
== "BROKEN_TRUNK"
== "FLAKY"
)
self.assertTrue(
checks[
"pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge)"
].classification
== "BROKEN_TRUNK"
== "FLAKY"
)
self.assertTrue(
checks[
@@ -704,15 +605,15 @@ class TestBypassFailures(TestCase):
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 2)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 4)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
# Not set any threshold, defaults to -1 to ignore all flaky and broken trunk failures
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 2)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 4)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
# Set the threshold lower than the number of ok failures
pending, failed, ignorable = categorize_checks(
@@ -720,8 +621,8 @@ class TestBypassFailures(TestCase):
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 6)
self.assertTrue(len(ignorable["FLAKY"]) == 2)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 4)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
# Set the threshold to 0 like when ignore_flaky_failures is on
pending, failed, ignorable = categorize_checks(
@@ -729,8 +630,8 @@ class TestBypassFailures(TestCase):
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 6)
self.assertTrue(len(ignorable["FLAKY"]) == 2)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 4)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
def test_get_classifications_flaky_fullname(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 110362)
@@ -739,8 +640,6 @@ class TestBypassFailures(TestCase):
pr.pr_num,
pr.project,
checks,
-pr.last_commit()["oid"],
-pr.get_merge_base(),
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
@@ -755,8 +654,6 @@ class TestBypassFailures(TestCase):
pr.pr_num,
pr.project,
checks,
-pr.last_commit()["oid"],
-pr.get_merge_base(),
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
@@ -773,8 +670,6 @@ class TestBypassFailures(TestCase):
pr.pr_num,
pr.project,
checks,
-pr.last_commit()["oid"],
-pr.get_merge_base(),
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
@@ -789,8 +684,6 @@ class TestBypassFailures(TestCase):
pr.pr_num,
pr.project,
checks,
-pr.last_commit()["oid"],
-pr.get_merge_base(),
[],
)
workflow_name = "linux-bionic-cuda12.1-py3.10-gcc9-bazel-test"
@@ -812,13 +705,13 @@ class TestBypassFailures(TestCase):
# This PR had one broken trunk failure but it was run on a different shard
# than the one on the base commit. This should still count as broken trunk
"pr_num": 104214,
"mock_merge_base": "436d035dc74db9c703297a62163b0cad0c546665",
"related_failure_count": 0,
"unrelated_failure_count": 1,
},
{
# This PR had one broken trunk failure and it used ghstack
"pr_num": 105145,
"mock_merge_base": "194fe1d12f9860734cc28ed21bdabda2fbb06336",
"related_failure_count": 0,
"unrelated_failure_count": 1,
},
{
@@ -827,41 +720,44 @@ class TestBypassFailures(TestCase):
# keep the failure record from the merge base so that it can
# be used to detect broken trunk
"pr_num": 107160,
"mock_merge_base": "a5d841ef01e615e2a654fb12cf0cd08697d12ccf",
"related_failure_count": 0,
"unrelated_failure_count": 4,
},
+{
+# This PR used Dr.CI broken trunk classification
+"pr_num": 111253,
+"related_failure_count": 1,
+"unrelated_failure_count": 2,
+},
]
for case in test_cases:
pr_num = case["pr_num"]
-mock_merge_base = case["mock_merge_base"]
+related_failure_count = case["related_failure_count"]
unrelated_failure_count = case["unrelated_failure_count"]
pr = GitHubPR("pytorch", "pytorch", cast(int, pr_num))
with mock.patch(
"trymerge.gh_fetch_merge_base", return_value=mock_merge_base
) as mocked_gh_fetch_merge_base:
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
pr.last_commit()["oid"],
pr.get_merge_base(),
[],
)
pr = GitHubPR("pytorch", "pytorch", pr_num)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, _ = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
pending, failed, _ = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == related_failure_count)
# When the ok_failed_checks_threshold is set to 0, the broken trunk failure
# won't be ignored
pending, failed, _ = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=0
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == unrelated_failure_count)
# When the ok_failed_checks_threshold is set to 0, the broken trunk failure
# won't be ignored
pending, failed, _ = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=0
)
self.assertTrue(len(pending) == 0)
self.assertTrue(
len(failed) == unrelated_failure_count + related_failure_count
)
def test_ignore_current(self, *args: Any) -> None:
# Test various interactions of the failure classifier to ensure that ignore
@@ -876,24 +772,6 @@ class TestBypassFailures(TestCase):
pr = GitHubPR("pytorch", "pytorch", 109584)
checks = pr.get_checkrun_conclusions()
-# No broken trunk or flaky as the merge base is not set, these failures are
-# counted as ignore current when ic is used
-checks = get_classifications(
-pr.pr_num,
-pr.project,
-checks,
-pr.last_commit()["oid"],
-None,
-[broken_trunk, flaky],
-)
-self.assertTrue(checks[flaky].classification == "IGNORE_CURRENT_CHECK")
-self.assertTrue(checks[broken_trunk].classification == "IGNORE_CURRENT_CHECK")
-_, failed, ignorable = categorize_checks(checks, list(checks.keys()))
-self.assertTrue(len(failed) == 4)
-self.assertTrue(len(ignorable["IGNORE_CURRENT_CHECK"]) == 2)
-self.assertTrue(len(ignorable["FLAKY"]) == 0)
-self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 0)
# Known flaky failure takes precedence over ignore current (need to set the
# merge base here to get the results from Rockset, and that categorize the
# broken trunk failure too
@@ -901,8 +779,6 @@ class TestBypassFailures(TestCase):
pr.pr_num,
pr.project,
checks,
-pr.last_commit()["oid"],
-pr.get_merge_base(),
[broken_trunk, flaky],
)
self.assertTrue(checks[flaky].classification == "FLAKY")
@@ -910,8 +786,8 @@ class TestBypassFailures(TestCase):
_, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["IGNORE_CURRENT_CHECK"]) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 2)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 4)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
@mock.patch("trymerge.read_merge_rules", side_effect=xla_merge_rules)
def test_dont_ignore_flaky_failures(self, *args: Any) -> None:

.github/scripts/trymerge.py

@@ -1271,22 +1271,12 @@ def find_matching_merge_rule(
if not rules:
reject_reason = f"Rejecting the merge as no rules are defined for the repository in {MERGE_RULE_PATH}"
raise RuntimeError(reject_reason)
checks = pr.get_checkrun_conclusions()
-base_rev = None
-try:
-# is allowed to fail if git is not available
-base_rev = pr.get_merge_base()
-except Exception as e:
-print(
-f"Failed fetching base git revision for {pr.pr_num}. Skipping additional classifications.\n"
-f"{type(e)}\n{e}"
-)
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
-pr.last_commit()["oid"],
-base_rev,
ignore_current_checks=ignore_current_checks,
)
@@ -1569,34 +1559,33 @@ def remove_job_name_suffix(name: str, replacement: str = ")") -> str:
def is_broken_trunk(
-head_job: Optional[Dict[str, Any]], base_jobs: Optional[Dict[str, Dict[str, Any]]]
+name: str,
+drci_classifications: Any,
) -> bool:
-if not head_job or not base_jobs:
+if not name or not drci_classifications:
return False
+# Consult the list of broken trunk failures from Dr.CI
return any(
head_job["conclusion"] == base_job["conclusion"]
and head_job["failure_captures"] == base_job["failure_captures"]
for base_job in base_jobs.values()
name == broken_trunk["name"]
for broken_trunk in drci_classifications.get("BROKEN_TRUNK", [])
)
def is_flaky(
-head_job: Optional[Dict[str, Any]],
+name: str,
drci_classifications: Any,
) -> bool:
-if not head_job or not drci_classifications:
+if not name or not drci_classifications:
return False
# Consult the list of flaky failures from Dr.CI
-return any(
-head_job.get("full_name", "") == flaky["name"]
-for flaky in drci_classifications.get("FLAKY", [])
-)
+return any(name == flaky["name"] for flaky in drci_classifications.get("FLAKY", []))
def is_invalid_cancel(
-head_job: Optional[Dict[str, Any]],
+name: str,
+conclusion: Optional[str],
drci_classifications: Any,
) -> bool:
"""
@@ -1604,18 +1593,18 @@ def is_invalid_cancel(
signals have been removed from HUD and Dr.CI. The same needs to be done
here for consistency
"""
-if not head_job or not drci_classifications:
-return False
-full_name = head_job.get("full_name", "")
-if head_job.get("conclusion", "") != "cancelled" or not full_name:
+if (
+not name
+or not drci_classifications
+or not conclusion
+or conclusion.upper() != "CANCELLED"
+):
return False
# If a job is cancelled and not listed as a failure by Dr.CI, it's an
# invalid signal and can be ignored
return all(
-full_name != failure["name"]
-for failure in drci_classifications.get("FAILED", [])
+name != failure["name"] for failure in drci_classifications.get("FAILED", [])
)
@@ -1623,62 +1612,8 @@ def get_classifications(
pr_num: int,
project: str,
checks: Dict[str, JobCheckState],
-head_sha: str,
-merge_base: Optional[str],
ignore_current_checks: Optional[List[str]],
) -> Dict[str, JobCheckState]:
-# Group by job name without shard id and suffix to correctly identify broken
-# trunk failures, i.e. linux-bionic-cuda12.1-py3.10-gcc9-sm86 / test (default)
-head_sha_jobs: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict)
-merge_base_jobs: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict)
-if merge_base is not None:
-def insert(
-d: Dict[str, Dict[str, Dict[str, Any]]],
-key: str,
-val: Dict[str, Any],
-overwrite_failed_run_attempt: bool,
-) -> None:
-key_no_suffix = remove_job_name_suffix(key)
-if key not in d[key_no_suffix]:
-d[key_no_suffix][key] = val
-return
-# When overwrite_failed_run_attempt is set to True, always overwrite
-# the job with the result from the latest attempt. This option is for
-# jobs from the pull request head_sha where the latest retry is used
-# when merging
-#
-# When overwrite_failed_run_attempt is False, only overwrite the job
-# with the result from the latest attempt if the latest retry failed.
-# This option is for jobs from the merger_base where we want to record
-# failures for broken trunk
-if d[key_no_suffix][key]["id"] < val["id"] and (
-overwrite_failed_run_attempt or not is_passing_status(val["conclusion"])
-):
-d[key_no_suffix][key] = val
-rockset_results = get_rockset_results(head_sha, merge_base)
-for rockset_result in rockset_results:
-name = f"{rockset_result['workflow_name']} / {rockset_result['name']}"
-rockset_result["full_name"] = name
-if rockset_result["head_sha"] == head_sha:
-insert(
-head_sha_jobs,
-name,
-rockset_result,
-overwrite_failed_run_attempt=True,
-)
-else:
-insert(
-merge_base_jobs,
-name,
-rockset_result,
-overwrite_failed_run_attempt=False,
-)
+# Get the failure classification from Dr.CI, which is the source of truth
+# going forward
drci_classifications = get_drci_classifications(pr_num=pr_num, project=project)
@@ -1686,7 +1621,7 @@ def get_classifications(
checks_with_classifications = checks.copy()
for name, check in checks.items():
if check.status == "SUCCESS":
if check.status == "SUCCESS" or check.status == "NEUTRAL":
continue
if "unstable" in name:
@@ -1700,13 +1635,9 @@ def get_classifications(
)
continue
-name_no_suffix = remove_job_name_suffix(name)
-head_sha_job = head_sha_jobs.get(name_no_suffix, {}).get(name)
-# NB: It's important to note that when it comes to ghstack and broken trunk classification,
-# the current implementation of trymerge uses the base of the current PR in the stack, i.e.
-# gh/user/xx/base, while Dr.CI uses the base of the whole stack. Both works though
-if is_broken_trunk(head_sha_job, merge_base_jobs.get(name_no_suffix)):
+# Dr.CI uses the base of the whole stack
+if is_broken_trunk(name, drci_classifications):
checks_with_classifications[name] = JobCheckState(
check.name,
check.url,
@@ -1717,12 +1648,13 @@ def get_classifications(
)
continue
-elif is_flaky(head_sha_job, drci_classifications):
+elif is_flaky(name, drci_classifications):
checks_with_classifications[name] = JobCheckState(
check.name, check.url, check.status, "FLAKY", check.job_id, check.title
)
continue
-elif is_invalid_cancel(head_sha_job, drci_classifications):
+elif is_invalid_cancel(name, check.status, drci_classifications):
# NB: Create a new category here for invalid cancelled signals because
# there are usually many of them when they happen. So, they shouldn't
# be counted toward ignorable failures threshold
@@ -2059,8 +1991,6 @@ def merge(
pr.pr_num,
pr.project,
checks,
-pr.last_commit()["oid"],
-pr.get_merge_base(),
ignore_current_checks=ignore_current_checks,
)
pending, failing, _ = categorize_checks(