mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	This is a bit weird: author_login is not a unique field, but author_url is. Explicitly allow https://github.com/apps/pytorch-auto-revert to issue revert commands. Update mocks by running ``` sed -i -e s/8e262b0495bd934d39dda198d4c09144311c5ddd6cca6a227194bd48dbfe7201/47860a8f57a214a426d1150c29893cbc2aa49507f12b731483b1a1254bca3428/ gql_mocks.json ``` Test plan: Run ```python from trymerge import GitHubPR pr=GitHubPR("pytorch", "pytorch", 164660) print(pr.get_last_comment().author_url, pr.get_comment_by_id(3375785595).author_url) ``` which should produce ``` https://github.com/pytorch-auto-revert https://github.com/apps/pytorch-auto-revert ``` Plus, added a regression test that checks two particular comments for revert validity. (The `pytorch-auto-revert` user is my alter ego :) ) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164911 Approved by: https://github.com/jeanschmidt
		
			
				
	
	
		
			1335 lines
		
	
	
		
			51 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			1335 lines
		
	
	
		
			51 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
#!/usr/bin/env python3
 | 
						|
# Tests implemented in this file rely on GitHub GraphQL APIs.
# In order to avoid test flakiness, results of the queries
# are cached in gql_mocks.json.gz (Dr. CI results in drci_mocks.json.gz).
# The PyTorch Lint workflow does not have GITHUB_TOKEN defined to avoid
# flakiness, so if you are making changes to merge_rules or
# GraphQL queries in trymerge.py, please make sure to delete `gql_mocks.json.gz`
# and re-run the test locally with your own PAT (Personal Access Token).
 | 
						|
 | 
						|
import gzip
 | 
						|
import json
 | 
						|
import os
 | 
						|
import warnings
 | 
						|
from hashlib import sha256
 | 
						|
from typing import Any, Optional
 | 
						|
from unittest import main, mock, skip, TestCase
 | 
						|
from urllib.error import HTTPError
 | 
						|
 | 
						|
from github_utils import gh_graphql
 | 
						|
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
 | 
						|
from trymerge import (
 | 
						|
    _revlist_to_prs,
 | 
						|
    categorize_checks,
 | 
						|
    DRCI_CHECKRUN_NAME,
 | 
						|
    find_matching_merge_rule,
 | 
						|
    get_classifications,
 | 
						|
    get_drci_classifications,
 | 
						|
    gh_get_team_members,
 | 
						|
    GitHubPR,
 | 
						|
    iter_issue_timeline_until_comment,
 | 
						|
    JobCheckState,
 | 
						|
    main as trymerge_main,
 | 
						|
    MandatoryChecksMissingError,
 | 
						|
    MergeRule,
 | 
						|
    PostCommentError,
 | 
						|
    RE_GHSTACK_DESC,
 | 
						|
    read_merge_rules,
 | 
						|
    remove_job_name_suffix,
 | 
						|
    sha_from_committed_event,
 | 
						|
    sha_from_force_push_after,
 | 
						|
    validate_revert,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
# Point the git helpers at upstream PyTorch unless the environment already
# specifies a remote; an existing value is left untouched.
os.environ.setdefault("GIT_REMOTE_URL", "https://github.com/pytorch/pytorch")

# Gzipped JSON caches of recorded responses, looked up next to this file by
# mock_query(): GraphQL replies and Dr. CI classifications respectively.
GQL_MOCKS = "gql_mocks.json.gz"
DRCI_MOCKS = "drci_mocks.json.gz"
 | 
						|
 | 
						|
 | 
						|
def mock_query(
    fallback_function: Any,
    file_name: str,
    key_function: Any,
    *args: Any,
) -> Any:
    """Return the cached response for ``key_function(*args)`` from a gzipped JSON file.

    Args:
        fallback_function: Called as ``fallback_function(*args)`` to fetch a live
            response when the key is missing (currently unreachable, see TODO).
        file_name: Cache file name, joined onto this file's directory with
            ``os.path.join`` (so an absolute path is used as-is).
        key_function: Called as ``key_function(*args)`` to build the cache key.
        *args: Forwarded to both ``key_function`` and ``fallback_function``.

    Returns:
        The cached response (or, once the TODO is resolved, the freshly fetched
        and persisted one).

    Raises:
        ValueError: If the key is not present in the cache file.
        RuntimeError: If the live fallback fails with HTTP 401/403.
        HTTPError: If the live fallback fails with any other HTTP status.
    """
    gql_db_fname = os.path.join(os.path.dirname(__file__), file_name)

    def get_mocked_queries() -> Any:
        # A missing cache file is treated as an empty cache.
        if not os.path.exists(gql_db_fname):
            return {}
        with gzip.open(gql_db_fname, encoding="utf-8", mode="rt") as f:
            return json.load(f)

    def save_mocked_queries(obj: Any) -> None:
        with gzip.open(gql_db_fname, encoding="utf-8", mode="wt") as f:
            json.dump(obj, f, indent=2)
            f.write("\n")

    key = key_function(*args)
    mocked_queries = get_mocked_queries()

    if key in mocked_queries:
        return mocked_queries[key]

    # TODO: Remove me once https://github.com/pytorch/pytorch/issues/160489 is resolved
    # Name the actual cache file: this helper serves both the GraphQL and Dr. CI mocks.
    raise ValueError(f"Key {key} could not be found in {file_name}")

    # Unreachable until the TODO above is removed: refresh the cache from the
    # live API and persist the result.
    try:
        rc = fallback_function(*args)
    except HTTPError as err:
        if err.code in (401, 403):
            err_msg = f"If you are seeing this message during workflow run, please make sure to update {file_name}"
            err_msg += f" locally, by deleting it and running {os.path.basename(__file__)} with"
            err_msg += " GitHub Personal Access Token passed via GITHUB_TOKEN"
            err_msg += " and drci api key passed via DRCI_BOT_KEY environment variables"
            if os.getenv("GITHUB_TOKEN") is None or os.getenv("DRCI_BOT_KEY") is None:
                err_msg = (
                    "Failed to update cached queries as GITHUB_TOKEN or DRCI_BOT_KEY "
                    + "is not defined. "
                    + err_msg
                )
            raise RuntimeError(err_msg) from err
        # Any other HTTP failure must propagate; falling through would leave
        # `rc` unbound and surface as a confusing UnboundLocalError below.
        raise
    mocked_queries[key] = rc

    save_mocked_queries(mocked_queries)

    return rc
 | 
						|
 | 
						|
 | 
						|
def mocked_gh_graphql(query: str, **kwargs: Any) -> Any:
    """Answer ``gh_graphql(query, **kwargs)`` from the recorded GraphQL cache.

    The cache key is ``query_sha=<sha256 of query> `` followed by the keyword
    arguments rendered as space-separated ``name=value`` pairs in sorted-name
    order.
    """

    def key_function(query: str, kwargs: Any) -> str:
        query_digest = sha256(query.encode("utf-8")).hexdigest()
        rendered_vars = " ".join(f"{name}={kwargs[name]}" for name in sorted(kwargs))
        return f"query_sha={query_digest} {rendered_vars}"

    def gh_graphql_wrapper(query: str, kwargs: Any) -> Any:
        # mock_query passes its extra arguments positionally, so the keyword
        # dict arrives as one value and is expanded back here.
        return gh_graphql(query, **kwargs)

    return mock_query(gh_graphql_wrapper, GQL_MOCKS, key_function, query, kwargs)
 | 
						|
 | 
						|
 | 
						|
def mocked_drci_classifications(pr_num: int, project: str, num_retries: int = 3) -> Any:
    """Answer Dr. CI classification lookups from the recorded Dr. CI cache.

    Entries are keyed by ``"<pr_num> <project>"``. ``num_retries`` is accepted
    but unused (presumably to mirror the patched ``get_drci_classifications``
    signature — confirm).
    """

    def drci_key(pr: int, proj: str) -> str:
        return f"{pr} {proj}"

    return mock_query(get_drci_classifications, DRCI_MOCKS, drci_key, pr_num, project)
 | 
						|
 | 
						|
 | 
						|
def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
    """Build a stand-in for trymerge's parsed command-line arguments.

    Only ``revert`` and ``force`` vary; every other attribute is fixed: PR 76123,
    dry run enabled, a non-zero comment id, and a canned reason.
    """
    settings = {
        "revert": revert,
        "force": force,
        "pr_num": 76123,
        "dry_run": True,
        "comment_id": 12345,  # Set to non-zero value
        "reason": "this is for testing",
        "ignore_current": False,
        "check_mergeability": False,
    }

    class Object:
        def __init__(self) -> None:
            for attr_name, attr_value in settings.items():
                setattr(self, attr_name, attr_value)

    return Object()
 | 
						|
 | 
						|
 | 
						|
def mock_remove_label(
    org: str, repo: str, pr_num: str, label: str, dry_run: bool
) -> None:
    """No-op side effect for the patched ``trymerge.gh_remove_label``.

    All arguments are ignored and no label is touched.
    """
    return None
 | 
						|
 | 
						|
 | 
						|
def mock_revert(
    repo: GitRepo,
    pr: GitHubPR,
    *,
    dry_run: bool = False,
    comment_id: Optional[int] = None,
    reason: Optional[str] = None,
) -> None:
    """No-op side effect for the patched ``trymerge.try_revert``.

    The revert test only checks that this was called once, so every argument
    is ignored and nothing is reverted.
    """
    return None
 | 
						|
 | 
						|
 | 
						|
def mock_merge(
    pr: GitHubPR,
    repo: GitRepo,
    comment_id: int,
    dry_run: bool = False,
    skip_mandatory_checks: bool = False,
    timeout_minutes: int = 400,
    stale_pr_days: int = 3,
    ignore_current: bool = False,
) -> None:
    """No-op side effect for the patched ``trymerge.merge``.

    The merge tests inspect the arguments recorded by the surrounding mock, so
    this body deliberately does nothing.
    """
    return None
 | 
						|
 | 
						|
 | 
						|
def mock_gh_get_info() -> Any:
    """Return a minimal PR-info payload for patching ``trymerge.gh_get_pr_info``.

    Describes an open, same-repository PR from branch ``foo`` into ``bar``
    (also the default branch) with no changed files. A fresh dict is built on
    every call.
    """
    no_files = {"nodes": [], "pageInfo": {"hasNextPage": False}}
    base_repo = {"defaultBranchRef": {"name": "bar"}}
    return dict(
        closed=False,
        isCrossRepository=False,
        headRefName="foo",
        baseRefName="bar",
        baseRepository=base_repo,
        files=no_files,
        changedFiles=0,
    )
 | 
						|
 | 
						|
 | 
						|
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> list[MergeRule]:
    """Return one catch-all rule whose mandatory checks include ``"nonexistent"``.

    Used to drive the pending/not-yet-run mandatory check path. All arguments
    are ignored.
    """
    required_checks = ["Lint", "Facebook CLA Check", "nonexistent"]
    catch_all_rule = MergeRule(
        name="mock with nonexistent check",
        patterns=["*"],
        approved_by=[],
        mandatory_checks_name=required_checks,
        ignore_flaky_failures=True,
    )
    return [catch_all_rule]
 | 
						|
 | 
						|
 | 
						|
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
    """Return two canned rules: a catch-all ``super`` rule and an XLA pin rule.

    Both rules ignore flaky failures. All arguments are ignored.
    """
    super_rule = MergeRule(
        name="super",
        patterns=["*"],
        approved_by=["pytorch/metamates", "ngimel"],
        mandatory_checks_name=[
            "Lint",
            "pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
        ],
        ignore_flaky_failures=True,
    )
    xla_pin_rule = MergeRule(
        name="xla",
        patterns=[".github/ci_commit_pins/xla.txt"],
        approved_by=["pytorchbot"],
        mandatory_checks_name=[
            "Lint",
            "EasyCLA",
            "pull / linux-focal-py3_8-clang9-xla / build",
            "pull / linux-focal-py3_8-clang9-xla / test (xla, 1, 1, linux.12xlarge)",
        ],
        ignore_flaky_failures=True,
    )
    return [super_rule, xla_pin_rule]
 | 
						|
 | 
						|
 | 
						|
def mocked_read_merge_rules_approvers(
    repo: Any, org: str, project: str
) -> list[MergeRule]:
    """Return two catch-all rules that differ only in their approver lists.

    Both require the ``Lint`` and ``pull`` checks and leave
    ``ignore_flaky_failures`` at the MergeRule default. All arguments are ignored.
    """

    def catch_all_rule(rule_name: str, approvers: list[str]) -> MergeRule:
        # Build a fresh checks list per rule so the rules share no mutable state.
        return MergeRule(
            name=rule_name,
            patterns=["*"],
            approved_by=approvers,
            mandatory_checks_name=["Lint", "pull"],
        )

    return [
        catch_all_rule("Core Reviewers", ["1", "2", "3", "4", "5", "6"]),
        catch_all_rule("Core Maintainers", ["1", "2", "malfet"]),
    ]
 | 
						|
 | 
						|
 | 
						|
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> list[MergeRule]:
    """Simulate a failure while loading merge rules.

    Raises:
        RuntimeError: always, with the message ``"testing"``.
    """
    failure = RuntimeError("testing")
    raise failure
 | 
						|
 | 
						|
 | 
						|
def xla_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
    """Return a single rule approved by ``pytorchbot`` for the XLA commit pin.

    The rule matches only ``.github/ci_commit_pins/xla.txt`` and does not
    ignore flaky failures. All arguments are ignored.
    """
    required_checks = [
        "Lint",
        "EasyCLA",
        "pull / linux-bionic-py3_8-clang8-xla / build",
        "pull / linux-bionic-py3_8-clang8-xla / test (xla, 1, 1, linux.4xlarge)",
        "inductor / cuda11.8-py3.10-gcc7-sm86 / test (inductor_torchbench_dynamic, 1, 1, linux.g5.4xlarge.nvidia.gpu)",
    ]
    xla_rule = MergeRule(
        name=" OSS CI / pytorchbot / XLA",
        patterns=[".github/ci_commit_pins/xla.txt"],
        approved_by=["pytorchbot"],
        mandatory_checks_name=required_checks,
        ignore_flaky_failures=False,
    )
    return [xla_rule]
 | 
						|
 | 
						|
 | 
						|
class DummyGitRepo(GitRepo):
    """GitRepo bound to the local checkout, with canned commit lookups.

    Resolving a PR always yields one placeholder SHA, and every ref reports
    the same commit message.
    """

    def __init__(self) -> None:
        repo_dir = get_git_repo_dir()
        remote_name = get_git_remote_name()
        super().__init__(repo_dir, remote_name)

    def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
        # The PR number is ignored; a single fake SHA is all the tests need.
        return ["FakeCommitSha"]

    def commit_message(self, ref: str) -> str:
        # Same message regardless of ref.
        return "super awesome commit message"
 | 
						|
 | 
						|
 | 
						|
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestTryMerge(TestCase):
    """Tests for trymerge.py that run against recorded API responses.

    The class-level patches route ``trymerge.gh_graphql`` and
    ``trymerge.get_drci_classifications`` through the gzipped mock caches.
    ``mock.patch`` as a class decorator only wraps methods whose names start
    with ``test``; each wrapped method receives the two mocks as extra
    positional arguments, which ``*args`` absorbs.
    """

    def test_merge_rules_valid(self, *args: Any) -> None:
        "Test that merge_rules.yaml can be parsed"
        repo = DummyGitRepo()
        merge_rules = read_merge_rules(repo, "pytorch", "pytorch")
        self.assertGreater(len(merge_rules), 1)

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
    def test_match_rules(self, *args: Any) -> None:
        "Tests that PR passes merge rules"
        pr = GitHubPR("pytorch", "pytorch", 109999)
        repo = DummyGitRepo()
        self.assertIsNotNone(find_matching_merge_rule(pr, repo))

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_raise)
    def test_read_merge_rules_fails(self, *args: Any) -> None:
        "Tests that PR fails to read the merge rules"
        pr = GitHubPR("pytorch", "pytorch", 77700)
        repo = DummyGitRepo()
        self.assertRaisesRegex(
            RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo)
        )

    @mock.patch(
        "trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_approvers
    )
    def test_match_rules_approvers(self, *args: Any) -> None:
        "Tests that PR has the necessary approvers"
        repo = DummyGitRepo()

        pr = GitHubPR("pytorch", "pytorch", 115329)
        # Test that all potential approvers across all rules are listed if the
        # PR doesn't have one of them
        for mock_rule in ["Core Reviewers", "Core Maintainers"]:
            self.assertRaisesRegex(
                RuntimeError,
                mock_rule,
                lambda: find_matching_merge_rule(pr, repo),
            )

        pr = GitHubPR("pytorch", "pytorch", 115495)
        # Test that PR with the correct approvers doesn't raise any exception
        self.assertIsNotNone(find_matching_merge_rule(pr, repo))

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
    def test_lint_fails(self, *args: Any) -> None:
        "Tests that PR fails mandatory lint check"
        pr = GitHubPR("pytorch", "pytorch", 90791)
        repo = DummyGitRepo()
        self.assertRaises(RuntimeError, lambda: find_matching_merge_rule(pr, repo))

    def test_get_last_comment(self, *args: Any) -> None:
        "Tests that last comment can be fetched"
        pr = GitHubPR("pytorch", "pytorch", 71759)
        comment = pr.get_last_comment()
        self.assertEqual(comment.author_login, "github-actions")
        self.assertIsNone(comment.editor_login)
        self.assertIn("You've committed this PR", comment.body_text)

    def test_get_author_null(self, *args: Any) -> None:
        """Tests that PR author can be computed
        If reply contains NULL
        """
        pr = GitHubPR("pytorch", "pytorch", 71759)
        author = pr.get_author()
        self.assertIsNotNone(author)
        self.assertIn("@", author)
        self.assertIsNone(pr.get_diff_revision())

        # PR with multiple contributors, but creator id is not among authors
        pr = GitHubPR("pytorch", "pytorch", 75095)
        self.assertEqual(pr.get_pr_creator_login(), "mruberry")
        author = pr.get_author()
        self.assertIsNotNone(author)

    def test_large_diff(self, *args: Any) -> None:
        "Tests that PR with 100+ files can be fetched"
        pr = GitHubPR("pytorch", "pytorch", 73099)
        self.assertGreater(pr.get_changed_files_count(), 100)
        flist = pr.get_changed_files()
        self.assertEqual(len(flist), pr.get_changed_files_count())

    def test_internal_changes(self, *args: Any) -> None:
        "Tests that PR with internal changes is detected"
        pr = GitHubPR("pytorch", "pytorch", 110140)
        self.assertTrue(pr.has_internal_changes())

    def test_comments_pagination(self, *args: Any) -> None:
        "Tests that PR with 50+ comments can be fetched"
        pr = GitHubPR("pytorch", "pytorch", 31093)
        self.assertGreater(len(pr.get_comments()), 50)

    def test_gql_complexity(self, *args: Any) -> None:
        "Fetch comments and conclusions for PR with 60 commits"
        # Previous version of GraphQL query used to cause HTTP/502 error
        # see https://gist.github.com/malfet/9b93bc7eeddeaf1d84546efc4f0c577f
        pr = GitHubPR("pytorch", "pytorch", 68111)
        self.assertGreater(len(pr.get_comments()), 20)
        # NS(09/27/2023): GitHub seems to recycle older checkruns
        # https://github.com/pytorch/pytorch/pull/68111/checks shows 0 runs
        # self.assertGreater(len(pr.get_checkrun_conclusions()), 3)
        self.assertGreater(pr.get_commit_count(), 60)

    @skip("GitHub doesn't keep this data anymore")
    def test_gql_retrieve_checksuites(self, *args: Any) -> None:
        "Tests that all 182 checkrun conclusions of PR 94787 can be retrieved"
        pr = GitHubPR("pytorch", "pytorch", 94787)
        self.assertEqual(len(pr.get_checkrun_conclusions()), 182)

    def test_team_members(self, *args: Any) -> None:
        "Test fetching team members works"
        dev_infra_team = gh_get_team_members("pytorch", "pytorch-dev-infra")
        self.assertGreater(len(dev_infra_team), 2)
        with self.assertWarns(Warning):
            non_existing_team = gh_get_team_members("pytorch", "qwertyuiop")
            self.assertEqual(len(non_existing_team), 0)

    def test_get_author_many_commits(self, *args: Any) -> None:
        """Tests that authors for all commits can be fetched"""
        pr = GitHubPR("pytorch", "pytorch", 76118)
        authors = pr.get_authors()
        self.assertGreater(pr.get_commit_count(), 100)
        self.assertGreater(len(authors), 50)
        self.assertIn("@", pr.get_author())

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_NE)
    def test_pending_status_check(self, *args: Any) -> None:
        """Tests that PR with nonexistent/pending status checks fails with the right reason."""
        pr = GitHubPR("pytorch", "pytorch", 76118)
        repo = DummyGitRepo()
        self.assertRaisesRegex(
            MandatoryChecksMissingError,
            ".*are pending/not yet run.*",
            lambda: find_matching_merge_rule(pr, repo),
        )

    def test_get_author_many_reviews(self, *args: Any) -> None:
        """Tests that all reviews can be fetched"""
        pr = GitHubPR("pytorch", "pytorch", 76123)
        approved_by = pr.get_approved_by()
        self.assertGreater(len(approved_by), 0)
        assert pr._reviews is not None  # to pacify mypy
        self.assertGreater(len(pr._reviews), 100)

    # NOTE(review): this method lacks the ``test_`` prefix, so unittest never
    # collects it and the class-level mock.patch decorators are not applied to
    # it. Rename to ``test_get_co_authors`` once mocks for PR 118347 are
    # confirmed to exist in gql_mocks.json.gz.
    def get_co_authors(self, *args: Any) -> None:
        """Tests that co-authors are recognized"""
        pr = GitHubPR("pytorch", "pytorch", 118347)
        authors = pr.get_authors()
        self.assertIn("kit1980", authors)
        self.assertIn("Co-authored-by:", pr.gen_commit_message())

    def test_get_checkruns_many_runs(self, *args: Any) -> None:
        """Tests that all checkruns can be fetched"""
        pr = GitHubPR("pytorch", "pytorch", 105260)
        conclusions = pr.get_checkrun_conclusions()
        self.assertEqual(len(conclusions), 221)
        self.assertIn("pull / linux-docs / build-docs-cpp-false", conclusions)

    def test_cancelled_gets_ignored(self, *args: Any) -> None:
        """Tests that cancelled workflow does not override existing successful status"""
        pr = GitHubPR("pytorch", "pytorch", 110367)
        conclusions = pr.get_checkrun_conclusions()
        lint_checks = [name for name in conclusions if "Lint" in name]
        self.assertGreater(len(lint_checks), 0)
        self.assertTrue(
            all(conclusions[name].status == "SUCCESS" for name in lint_checks)
        )

    def test_get_review_comment_by_id(self, *args: Any) -> None:
        """Tests that even if the comment requested was actually a review instead of a simple comment, we can still find it"""
        pr = GitHubPR("pytorch", "pytorch", 107070)
        review_comment_id = 1582767635
        comment = pr.get_comment_by_id(review_comment_id)
        self.assertIsNotNone(comment)

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(True, False))
    @mock.patch("trymerge.try_revert", side_effect=mock_revert)
    def test_main_revert(self, mock_revert: Any, *args: Any) -> None:
        trymerge_main()
        mock_revert.assert_called_once()

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, True))
    @mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
    @mock.patch("trymerge.merge", side_effect=mock_merge)
    def test_main_force(self, mock_merge: Any, *args: Any) -> None:
        # Patch mocks arrive innermost-decorator first, so only the `merge` mock
        # is named; the gh_remove_label/parse_args/... mocks land in *args.
        trymerge_main()
        mock_merge.assert_called_once_with(
            mock.ANY,
            mock.ANY,
            comment_id=mock.ANY,
            dry_run=mock.ANY,
            skip_mandatory_checks=True,
            ignore_current=False,
        )

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, False))
    @mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
    @mock.patch("trymerge.merge", side_effect=mock_merge)
    def test_main_merge(self, mock_merge: Any, *args: Any) -> None:
        trymerge_main()
        mock_merge.assert_called_once_with(
            mock.ANY,
            mock.ANY,
            comment_id=mock.ANY,
            dry_run=mock.ANY,
            skip_mandatory_checks=False,
            ignore_current=False,
        )

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
    def test_revert_rules(self, *args: Any) -> None:
        """Tests that reverts from collaborators are allowed"""
        pr = GitHubPR("pytorch", "pytorch", 79694)
        repo = DummyGitRepo()
        self.assertIsNotNone(validate_revert(repo, pr, comment_id=1189459845))

    def test_get_changed_files(self, *args: Any) -> None:
        """
        Tests that the list changed files in a PR doesn't include duplicates
        """
        pr = GitHubPR("pytorch", "pytorch", 95233)
        try:
            changed_files = pr.get_changed_files()
        except RuntimeError as error:
            self.fail(f"get_changed_files throws an exception: {error}")

        self.assertEqual(len(changed_files), pr.get_changed_files_count())

    def test_revert_codev_abandoned_diff_succeeds(self, *args: Any) -> None:
        """Tests that validate_revert accepts PR 100652 when its commit message is the PR body"""
        pr = GitHubPR("pytorch", "pytorch", 100652)

        class GitRepoCoDev(DummyGitRepo):
            def commit_message(self, ref: str) -> str:
                return pr.get_body()

        repo = GitRepoCoDev()
        validate_revert(repo, pr, comment_id=1588195237)

    def test_pr_changed_submodule_detection(self, *args: Any) -> None:
        # Updates submodule during dev-cycle but reverts it later
        pr = GitHubPR("pytorch", "pytorch", 95045)
        self.assertEqual(pr.get_changed_submodules(), [])
        self.assertFalse(pr.has_invalid_submodule_updates())

        # PR updates ideep
        pr = GitHubPR("pytorch", "pytorch", 94939)
        self.assertEqual(pr.get_changed_submodules(), ["third_party/ideep"])
        self.assertTrue(pr.has_invalid_submodule_updates())

        # Automated submodule update
        pr = GitHubPR("pytorch", "pytorch", 91051)
        self.assertEqual(pr.get_changed_submodules(), ["third_party/kineto"])
        self.assertFalse(pr.has_invalid_submodule_updates())

    def test_remove_job_name_suffix(self, *args: Any) -> None:
        test_cases = [
            {
                "name": "linux-bionic-cuda12.6-py3.10-gcc9-sm86 / test (default, 1, 5, linux.g5.4xlarge.nvidia.gpu)",
                "expected": "linux-bionic-cuda12.6-py3.10-gcc9-sm86 / test (default)",
            },
            {
                "name": "android-emulator-build-test / build-and-test (default, 1, 1, ubuntu-20.04-16x)",
                "expected": "android-emulator-build-test / build-and-test (default)",
            },
            {
                "name": "linux-focal-rocm5.4.2-py3.8 / build",
                "expected": "linux-focal-rocm5.4.2-py3.8 / build",
            },
            {
                "name": "libtorch-cpu-shared-with-deps-release-build",
                "expected": "libtorch-cpu-shared-with-deps-release-build",
            },
            {
                "name": "manywheel-py3_8-cuda11_8-test / test",
                "expected": "manywheel-py3_8-cuda11_8-test / test",
            },
            {
                "name": "lintrunner / linux-job",
                "expected": "lintrunner / linux-job",
            },
            {
                "name": "Test `run_test.py` is usable without boto3",
                "expected": "Test `run_test.py` is usable without boto3",
            },
        ]

        for case in test_cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(name=case["name"]):
                self.assertEqual(case["expected"], remove_job_name_suffix(case["name"]))

    def test_get_merge_base(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 104121)

        mock_merge_base = "mocked-sha"
        with mock.patch(
            "trymerge.gh_fetch_merge_base", return_value=mock_merge_base
        ) as mocked_gh_fetch_merge_base:
            self.assertEqual(mock_merge_base, pr.get_merge_base())

            # Make sure that consecutive calls will use the same merge base instead of
            # making another query
            self.assertEqual(mock_merge_base, pr.get_merge_base())
            mocked_gh_fetch_merge_base.assert_called_once()

    def test_app_can_revert(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 164660)
        repo = DummyGitRepo()
        app_comment_id, impostor_comment_id = 3375785595, 3377647892
        # Check that app can revert
        self.assertIsNotNone(validate_revert(repo, pr, comment_id=app_comment_id))
        # But impostor can not
        self.assertRaises(
            PostCommentError,
            lambda: validate_revert(repo, pr, comment_id=impostor_comment_id),
        )
        # Despite its name being the name of the bot
        self.assertEqual(
            pr.get_comment_by_id(impostor_comment_id).author_login,
            "pytorch-auto-revert",
        )
 | 
						|
 | 
						|
 | 
						|
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestBypassFailures(TestCase):
    """Failure classification (flaky, broken trunk, unstable, ...) and how
    categorize_checks treats each classification."""

    @staticmethod
    def _classify(
        pr: GitHubPR,
        checks: Optional[Any] = None,
        ignore_current_checks: Optional[list[str]] = None,
    ) -> Any:
        """Run check runs through get_classifications for ``pr``.

        ``checks`` defaults to ``pr.get_checkrun_conclusions()`` and
        ``ignore_current_checks`` defaults to an empty list.
        """
        if checks is None:
            checks = pr.get_checkrun_conclusions()
        return get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            ignore_current_checks if ignore_current_checks is not None else [],
        )

    def test_get_classifications(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 109584)
        checks = self._classify(pr)

        expected_classifications = {
            "pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)": "BROKEN_TRUNK",
            "trunk / win-vs2019-cpu-py3 / test (default, 2, 3, windows.4xlarge.nonephemeral)": "FLAKY",
            "pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge)": "FLAKY",
            "pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)": "FLAKY",
        }
        for name, classification in expected_classifications.items():
            self.assertEqual(checks[name].classification, classification)

        # This PR has 4 FLAKY + 2 BROKEN_TRUNK = 6 ignorable failures. Each case is
        # (ok_failed_checks_threshold, expected number of failed checks):
        #   6    -> threshold >= number of ok failures, so all of them are ignored
        #   None -> no threshold passed, all flaky and broken trunk failures are ignored
        #   1    -> threshold lower than the number of ok failures, all of them fail
        #   0    -> no flaky or broken trunk failure is tolerated
        for threshold, expected_failed in ((6, 0), (None, 0), (1, 6), (0, 6)):
            with self.subTest(ok_failed_checks_threshold=threshold):
                kwargs = (
                    {} if threshold is None else {"ok_failed_checks_threshold": threshold}
                )
                pending, failed, ignorable = categorize_checks(
                    checks, list(checks.keys()), **kwargs
                )
                self.assertEqual(len(pending), 0)
                self.assertEqual(len(failed), expected_failed)
                self.assertEqual(len(ignorable["FLAKY"]), 4)
                self.assertEqual(len(ignorable["BROKEN_TRUNK"]), 2)

    def test_get_classifications_flaky_fullname(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 110362)
        checks = self._classify(pr)
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["FLAKY"]), 1)

    def test_get_classifications_invalid_cancel(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 110367)
        checks = self._classify(pr)
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["FLAKY"]), 0)
        self.assertEqual(len(ignorable["BROKEN_TRUNK"]), 0)
        self.assertEqual(len(ignorable["UNSTABLE"]), 3)

    def test_get_classifications_similar_failures(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 109750)
        checks = self._classify(pr)
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["FLAKY"]), 1)

    def test_get_classifications_unstable(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 104312)
        checks = self._classify(pr)
        workflow_name = "linux-bionic-cuda12.1-py3.10-gcc9-bazel-test"
        job_name = "build-and-test (default, 1, 1, linux.4xlarge.nvidia.gpu, unstable)"
        self.assertEqual(
            checks[f"pull / {workflow_name} / {job_name}"].classification, "UNSTABLE"
        )
        pending, failed, ignorable = categorize_checks(
            checks, list(checks.keys()), ok_failed_checks_threshold=1
        )
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["UNSTABLE"]), 1)

        # Add another test case where there is no unstable keyword in the job name, but
        # the job has already been marked as unstable
        pr = GitHubPR("pytorch", "executorch", 3318)
        checks = self._classify(pr)
        workflow_name = "test-llama-app"
        job_name = "mobile-job (android)"
        self.assertEqual(
            checks[f"Android / {workflow_name} / {job_name}"].classification,
            "UNSTABLE",
        )
        pending, failed, ignorable = categorize_checks(
            checks, list(checks.keys()), ok_failed_checks_threshold=1
        )
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["UNSTABLE"]), 1)

    def test_get_classifications_broken_trunk(self, *args: Any) -> None:
        # Each case: PR number, number of genuinely related failures, and number of
        # failures classified as flaky or broken trunk
        test_cases = [
            {
                # This PR had one broken trunk failure but it was run on a different shard
                # than the one on the base commit. This should still count as broken trunk
                "pr_num": 104214,
                "related_failure_count": 0,
                "flaky_or_broken_trunk": 1,
            },
            {
                # This PR had one broken trunk failure and it used ghstack
                "pr_num": 105145,
                "related_failure_count": 0,
                "flaky_or_broken_trunk": 1,
            },
            {
                # The failure on the merge base was retried successfully and
                # its conclusion changed from failure to success. We want to
                # keep the failure record from the merge base so that it can
                # be used to detect broken trunk
                "pr_num": 107160,
                "related_failure_count": 0,
                "flaky_or_broken_trunk": 1,
            },
            {
                # This PR used Dr.CI broken trunk classification
                "pr_num": 111253,
                "related_failure_count": 1,
                "flaky_or_broken_trunk": 1,
            },
        ]

        for case in test_cases:
            pr_num = case["pr_num"]
            related_failure_count = case["related_failure_count"]
            flaky_or_broken_trunk = case["flaky_or_broken_trunk"]

            with self.subTest(pr_num=pr_num):
                pr = GitHubPR("pytorch", "pytorch", pr_num)
                checks = self._classify(pr)

                pending, failed, _ = categorize_checks(checks, list(checks.keys()))
                self.assertEqual(len(pending), 0)
                self.assertEqual(len(failed), related_failure_count)

                # When the ok_failed_checks_threshold is set to 0, the broken trunk failure
                # won't be ignored
                pending, failed, _ = categorize_checks(
                    checks, list(checks.keys()), ok_failed_checks_threshold=0
                )
                self.assertEqual(len(pending), 0)
                self.assertEqual(
                    len(failed), flaky_or_broken_trunk + related_failure_count
                )

    def test_ignore_current(self, *args: Any) -> None:
        # Test various interactions of the failure classifier to ensure that ignore
        # current checks takes place after other classifications: flaky, unstable,
        # or broken trunk. Only actual new failures should be kept in the list of
        # ignore current checks to use to record force merge with actual failures
        flaky = "pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)"
        broken_trunk = (
            "pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)"
        )

        pr = GitHubPR("pytorch", "pytorch", 109584)

        # Both checks are passed as ignore-current, yet the known flaky and broken
        # trunk classifications take precedence and nothing lands in
        # IGNORE_CURRENT_CHECK
        checks = self._classify(pr, ignore_current_checks=[broken_trunk, flaky])
        self.assertEqual(checks[flaky].classification, "FLAKY")
        self.assertEqual(checks[broken_trunk].classification, "BROKEN_TRUNK")
        _, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["IGNORE_CURRENT_CHECK"]), 0)
        self.assertEqual(len(ignorable["FLAKY"]), 4)
        self.assertEqual(len(ignorable["BROKEN_TRUNK"]), 2)

    def test_get_classifications_wrong_workflow_name(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 123104)
        checks = pr.get_checkrun_conclusions()

        check_name = "linux-binary-conda / conda-py3_8-cuda11_8-build / build"
        check_name_workflow_path = ".github/workflows/generated-linux-binary-conda-nightly.yml / conda-py3_8-cuda11_8-build / build"

        # Mock a check where the workflow name uses the full path
        checks[check_name_workflow_path] = JobCheckState(
            check_name_workflow_path,
            checks[check_name].url,
            checks[check_name].status,
            checks[check_name].classification,
            checks[check_name].job_id,
            checks[check_name].title,
            checks[check_name].summary,
        )
        del checks[check_name]

        checks = self._classify(pr, checks)
        pending, failed, ignorable = categorize_checks(
            checks,
            list(checks.keys()),
        )

        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["FLAKY"]), 1)
        self.assertEqual(len(ignorable["BROKEN_TRUNK"]), 0)

    def test_ignore_failures_older_run_same_workflow(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 129013)
        checks = self._classify(pr)
        pending, failed, ignorable = categorize_checks(
            checks,
            list(checks.keys()),
        )
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["FLAKY"]), 2)
        self.assertEqual(len(ignorable["UNSTABLE"]), 13)

    @mock.patch("trymerge.read_merge_rules", side_effect=xla_merge_rules)
    def test_dont_ignore_flaky_failures(self, *args: Any) -> None:
        """
        Regression test for https://github.com/pytorch/test-infra/issues/4126
        """
        pr = GitHubPR("pytorch", "pytorch", 105312)
        repo = DummyGitRepo()
        # Check that failure is classified as flaky but still raises exception
        with warnings.catch_warnings(record=True) as w, self.assertRaises(RuntimeError):
            find_matching_merge_rule(pr, repo)
        self.assertEqual(len(w), 1)
        self.assertIn(
            "1 checks failed but were likely due flakiness or broken trunk",
            str(w[0].message),
        )
 | 
						|
 | 
						|
 | 
						|
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch("trymerge.get_drci_classifications", return_value={})
class TestBypassFailuresOnSandCastle(TestCase):
    """Classification when the Dr.CI API returns nothing (get_drci_classifications
    is mocked to return an empty dict)."""

    def test_get_classifications(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 111467)
        checks = pr.get_checkrun_conclusions()
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["FLAKY"]), 1)
        self.assertEqual(len(ignorable["BROKEN_TRUNK"]), 1)

    def test_get_classifications_drci_checkrun_not_found(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 111467)

        def drci_checkrun(summary: Optional[str]) -> JobCheckState:
            """Build a NEUTRAL Dr.CI check run carrying the given summary."""
            return JobCheckState(
                DRCI_CHECKRUN_NAME,
                "",
                "NEUTRAL",
                None,
                1,
                "",
                summary,
            )

        # Without a usable Dr.CI check run, both failures must be reported. The
        # scenarios run in this order on purpose: each one starts from
        # pr.get_checkrun_conclusions() and edits the Dr.CI entry in place.
        # A None replacement means the Dr.CI check run is removed entirely.
        scenarios = (
            ("no summary", drci_checkrun(None)),
            ("empty summary", drci_checkrun("")),
            ("no Dr.CI checkrun", None),
        )
        for description, drci_state in scenarios:
            with self.subTest(description):
                checks = pr.get_checkrun_conclusions()
                if drci_state is None:
                    del checks[DRCI_CHECKRUN_NAME]
                else:
                    checks[DRCI_CHECKRUN_NAME] = drci_state
                checks = get_classifications(
                    pr.pr_num,
                    pr.project,
                    checks,
                    [],
                )
                pending, failed, _ = categorize_checks(checks, list(checks.keys()))
                self.assertEqual(len(pending), 0)
                self.assertEqual(len(failed), 2)
 | 
						|
 | 
						|
 | 
						|
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestGitHubPRGhstackDependencies(TestCase):
    """Commit message generation and merging for PRs that belong to a ghstack."""

    def test_pr_dependencies(self, *args: Any) -> None:
        pr = GitHubPR("pytorch", "pytorch", 106068)
        msg = pr.gen_commit_message(filter_ghstack=True)
        self.assertEqual(
            msg,
            f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
            "Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: https://github.com/ezyang, https://github.com/fegin\n",
        )

    def test_pr_dependencies_ghstack(self, *args: Any) -> None:
        pr0 = GitHubPR("pytorch", "pytorch", 106032)
        pr1 = GitHubPR("pytorch", "pytorch", 106033)
        pr2 = GitHubPR("pytorch", "pytorch", 106034)
        pr = GitHubPR("pytorch", "pytorch", 106068)
        msg = pr.gen_commit_message(filter_ghstack=True, ghstack_deps=[pr0, pr1, pr2])
        self.assertEqual(
            msg,
            f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
            "Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: https://github.com/ezyang, https://github.com/fegin\n"
            "ghstack dependencies: #106032, #106033, #106034\n",
        )

    @skip(
        reason="This test is run against a mutable PR that has changed, so it no longer works. The test should be changed"
    )
    @mock.patch("trymerge.read_merge_rules")
    @mock.patch("trymerge.GitRepo")
    @mock.patch("trymerge.get_ghstack_prs")
    def test_merge_ghstack_into(
        self,
        mock_get_ghstack_prs: mock.MagicMock,
        mock_repo: mock.MagicMock,
        mock_merge_rules: mock.MagicMock,
        *args: Any,
    ) -> None:
        """
        Test that the merge_ghstack_into method works correctly
        """
        pr0 = GitHubPR("pytorch", "pytorch", 106032)
        pr1 = GitHubPR("pytorch", "pytorch", 106033)
        pr2 = GitHubPR("pytorch", "pytorch", 106034)
        pr = GitHubPR("pytorch", "pytorch", 106068)

        # note: in reverse order (e.g. self.pr is the last commit, top of the stack)
        mock_get_ghstack_prs.return_value = [
            (pr0, "rev0"),
            (pr1, "rev1"),
            (pr2, "rev2"),
            (pr, "rev123"),
        ]

        mock_merge_rules.return_value = [
            MergeRule(
                "Mock title", patterns=["*"], approved_by=[], mandatory_checks_name=None
            )
        ]

        mock_repo.cherry_pick.return_value = None
        mock_repo.amend_commit_message.return_value = None

        # Call the method under test
        res = pr.merge_ghstack_into(mock_repo, True)

        self.assertEqual(res, [pr2, pr])

        mock_repo.cherry_pick.assert_any_call("rev2")
        mock_repo.cherry_pick.assert_any_call("rev123")

        # rev1 must not be cherry-picked
        self.assertNotIn(mock.call("rev1"), mock_repo.cherry_pick.call_args_list)

        # Verify the first call
        message = mock_repo.amend_commit_message.call_args_list[0].args[0]
        prefix = (
            "[FSDP] Optimize away intermediate `div_` for HSDP (#106034)\n\n\r\n"
            "### Background: Gradient Pre-Divide"
        )
        suffix = (
            "\nPull Request resolved: https://github.com/pytorch/pytorch/pull/106034\nApproved by: \nghstack "
            "dependencies: #106032, #106033\n"
        )

        self.assertTrue(message.startswith(prefix))
        self.assertTrue(message.endswith(suffix))

        # Verify the second call
        mock_repo.amend_commit_message.assert_any_call(
            "[FSDP] Break up `_post_backward_hook` into smaller funcs (#106068)\n\n\n"
            "Differential Revision: ["
            "D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)\n"
            "Pull Request resolved: "
            "https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: \n"
            "ghstack dependencies: #106032, #106033, #106034\n"
        )
 | 
						|
 | 
						|
 | 
						|
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
@mock.patch.object(DummyGitRepo, "commit_message")
class TestRevListToPR(TestCase):
    """Tests for _revlist_to_prs: each commit must reference exactly one PR."""

    def test__revlist_to_prs_zero_matches(
        self, mock_commit_message: mock.MagicMock, *args: Any
    ) -> None:
        # A commit message that mentions no PR at all must be rejected
        mock_commit_message.return_value = "no PRs"
        source_pr = GitHubPR("pytorch", "pytorch", 154098)
        git_repo = DummyGitRepo()

        expected_error = "PRs mentioned in commit dummy: 0."
        with self.assertRaisesRegex(RuntimeError, expected_error):
            _revlist_to_prs(git_repo, source_pr, ["dummy"])

    def test__revlist_to_prs_two_prs(
        self, mock_commit_message: mock.MagicMock, *args: Any
    ) -> None:
        # A commit message that mentions two PRs is ambiguous and must be rejected
        source_pr = GitHubPR("pytorch", "pytorch", 154394)
        git_repo = DummyGitRepo()

        # https://github.com/pytorch/pytorch/commit/343c56e7650f55fd030aca0b9275d6d73501d3f4
        mock_commit_message.return_value = """add sticky cache pgo

ghstack-source-id: 9bc6dee0b427819f978bfabccb72727ba8be2f81
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/154098

ghstack-source-id: 9bc6dee0b427819f978bfabccb72727ba8be2f81
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154394"""

        expected_error = "PRs mentioned in commit dummy: 2."
        with self.assertRaisesRegex(RuntimeError, expected_error):
            _revlist_to_prs(git_repo, source_pr, ["dummy"])
 | 
						|
 | 
						|
 | 
						|
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestTimelineFunctions(TestCase):
    """Tests for the issue-timeline helpers in trymerge.

    Covers ``sha_from_committed_event``, ``sha_from_force_push_after``,
    ``iter_issue_timeline_until_comment`` and
    ``GitHubPR.get_commit_sha_at_comment``.

    The class-level patches replace ``gh_graphql``, ``gh_fetch_merge_base`` and
    ``get_drci_classifications`` with this module's mock implementations, so
    ``GitHubPR(...)`` can be constructed in every test. Those mocks are passed
    to each test method and absorbed by ``*args``.
    """

    def test_sha_from_committed_event(self, *args: Any) -> None:
        """Test extracting the SHA from a ``committed`` timeline event."""
        # Based on actual GitHub API format - committed events have "sha" at top level
        event = {
            "event": "committed",
            "sha": "fb21ce932ded6670c918804a0d9151b773770a7c",
        }
        self.assertEqual(
            sha_from_committed_event(event), "fb21ce932ded6670c918804a0d9151b773770a7c"
        )

        # Test with missing SHA
        event_no_sha = {"event": "committed"}
        self.assertIsNone(sha_from_committed_event(event_no_sha))

    def test_sha_from_force_push_after(self, *args: Any) -> None:
        """Test extracting the post-push SHA from a force-push event.

        Two event shapes must both be understood: the legacy shape with the
        SHA nested under ``after`` and the GitHub timeline API shape with the
        SHA in a top-level ``commit_id`` field. An event carrying neither
        yields ``None``.
        """
        # Legacy shape: SHA nested under "after"
        event_legacy = {
            "event": "head_ref_force_pushed",
            "after": {"sha": "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e"},
        }
        self.assertEqual(
            sha_from_force_push_after(event_legacy),
            "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        )

        # GitHub API shape: SHA in top-level "commit_id"
        event_real_api = {
            "event": "head_ref_force_pushed",
            "commit_id": "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        }
        self.assertEqual(
            sha_from_force_push_after(event_real_api),
            "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        )

        # Neither shape present -> no SHA
        event_no_sha = {"event": "head_ref_force_pushed"}
        self.assertIsNone(sha_from_force_push_after(event_no_sha))

    @mock.patch("trymerge.gh_fetch_json_list")
    def test_iter_issue_timeline_until_comment(
        self, mock_gh_fetch_json_list: Any, *args: Any
    ) -> None:
        """Iteration yields events up to and including the target comment."""
        # Mock timeline data based on actual GitHub API format
        timeline_data = [
            {"event": "commented", "id": 100, "body": "first comment"},
            {"event": "committed", "sha": "fb21ce932ded6670c918804a0d9151b773770a7c"},
            {"event": "commented", "id": 200, "body": "target comment"},
            {"event": "commented", "id": 300, "body": "after target"},
        ]
        mock_gh_fetch_json_list.return_value = timeline_data

        # Comment 300 comes after the target and must not be yielded
        events = list(iter_issue_timeline_until_comment("pytorch", "pytorch", 123, 200))
        self.assertEqual(len(events), 3)
        self.assertEqual(events[0]["event"], "commented")
        self.assertEqual(events[0]["id"], 100)
        self.assertEqual(events[1]["event"], "committed")
        self.assertEqual(events[1]["sha"], "fb21ce932ded6670c918804a0d9151b773770a7c")
        self.assertEqual(events[2]["event"], "commented")
        self.assertEqual(events[2]["id"], 200)

    @mock.patch("trymerge.gh_fetch_json_list")
    def test_iter_issue_timeline_until_comment_not_found(
        self, mock_gh_fetch_json_list: Any, *args: Any
    ) -> None:
        """An empty timeline yields no events when the target comment is absent."""
        mock_gh_fetch_json_list.return_value = []

        events = list(iter_issue_timeline_until_comment("pytorch", "pytorch", 123, 999))
        self.assertEqual(len(events), 0)

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_commit_after_comment(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """A force push after the target comment does not change the result.

        The last ``committed`` SHA seen before comment 100 is returned.
        """
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "commented", "id": 100},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit2")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_force_push_before_comment(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """A force push (``commit_id`` shape) before the comment supersedes earlier commits."""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "head_ref_force_pushed", "commit_id": "commit3"},
            {"event": "commented", "id": 100},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_force_push_before_comment_legacy_mode(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Same as above, but the force push uses the legacy ``after.sha`` shape."""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
            {"event": "commented", "id": 100},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_multiple_comments(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Each comment id resolves to the head SHA as of that comment.

        The mock returns the full timeline on every call, so the method itself
        must stop at the requested comment rather than consume later events.
        """
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "commented", "id": 100},
            {"event": "committed", "sha": "commit2"},
            {"event": "commented", "id": 200},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
            {"event": "commented", "id": 300},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(200)
        self.assertEqual(sha, "commit2")
        sha = pr.get_commit_sha_at_comment(300)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_no_events(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Returns None when no commit or force-push event precedes the comment."""
        mock_iter_timeline.return_value = [
            {"event": "commented", "id": 100},
            {"event": "labeled", "label": {"name": "test"}},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertIsNone(sha)

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_exception(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Returns None (rather than raising) when timeline iteration fails."""
        mock_iter_timeline.side_effect = Exception("API error")
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertIsNone(sha)
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Only when executed as a script (not on import): run this module's
    # TestCase classes through unittest's command-line runner.
    main()
 |