mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 20:34:54 +08:00 
			
		
		
		
	This is a bit weird, but author_login is not a unique field, but author_url is. Explicitly allow https://github.com/apps/pytorch-auto-revert to issue revert commands Update mocks by running ``` sed -i -e s/8e262b0495bd934d39dda198d4c09144311c5ddd6cca6a227194bd48dbfe7201/47860a8f57a214a426d1150c29893cbc2aa49507f12b731483b1a1254bca3428/ gql_mocks.json ``` Test plan: Run ```python from trymerge import GitHubPR pr=GitHubPR("pytorch", "pytorch", 164660) print(pr.get_last_comment().author_url, pr.get_comment_by_id(3375785595).author_url) ``` that should produce ``` https://github.com/pytorch-auto-revert https://github.com/apps/pytorch-auto-revert ``` Plus added a regression test that checks two particular comments for revert validity `pytorch-auto-revert` user is my alter ego :) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164911 Approved by: https://github.com/jeanschmidt
		
			
				
	
	
		
			1335 lines
		
	
	
		
			51 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			1335 lines
		
	
	
		
			51 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
| #!/usr/bin/env python3
 | |
| # Tests implemented in this file are relying on GitHub GraphQL APIs
 | |
| # In order to avoid test flakiness, results of the queries
 | |
| # are cached in gql_mocks.json
 | |
| # PyTorch Lint workflow does not have GITHUB_TOKEN defined to avoid
 | |
| # flakiness, so if you are making changes to merge_rules or
 | |
| # GraphQL queries in trymerge.py, please make sure to delete `gql_mocks.json`
 | |
# And re-run the test locally with one's PAT
 | |
| 
 | |
| import gzip
 | |
| import json
 | |
| import os
 | |
| import warnings
 | |
| from hashlib import sha256
 | |
| from typing import Any, Optional
 | |
| from unittest import main, mock, skip, TestCase
 | |
| from urllib.error import HTTPError
 | |
| 
 | |
| from github_utils import gh_graphql
 | |
| from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
 | |
| from trymerge import (
 | |
|     _revlist_to_prs,
 | |
|     categorize_checks,
 | |
|     DRCI_CHECKRUN_NAME,
 | |
|     find_matching_merge_rule,
 | |
|     get_classifications,
 | |
|     get_drci_classifications,
 | |
|     gh_get_team_members,
 | |
|     GitHubPR,
 | |
|     iter_issue_timeline_until_comment,
 | |
|     JobCheckState,
 | |
|     main as trymerge_main,
 | |
|     MandatoryChecksMissingError,
 | |
|     MergeRule,
 | |
|     PostCommentError,
 | |
|     RE_GHSTACK_DESC,
 | |
|     read_merge_rules,
 | |
|     remove_job_name_suffix,
 | |
|     sha_from_committed_event,
 | |
|     sha_from_force_push_after,
 | |
|     validate_revert,
 | |
| )
 | |
| 
 | |
| 
 | |
# Ensure a remote URL is always defined so tests can run outside a
# fully-configured checkout (presumably consumed by the gitutils helpers
# imported above — confirm).
if "GIT_REMOTE_URL" not in os.environ:
    os.environ["GIT_REMOTE_URL"] = "https://github.com/pytorch/pytorch"

# Gzip-compressed caches of GraphQL and Dr.CI query results used by mock_query.
GQL_MOCKS = "gql_mocks.json.gz"
DRCI_MOCKS = "drci_mocks.json.gz"
 | |
| 
 | |
| 
 | |
def mock_query(
    fallback_function: Any,
    file_name: str,
    key_function: Any,
    *args: Any,
) -> Any:
    """Return a cached query result, keyed by ``key_function(*args)``.

    Results are stored gzip-compressed in ``file_name`` next to this file.
    On a cache miss the ``fallback_function(*args)`` would normally be invoked
    and its result saved back to the cache; that path is currently disabled by
    an unconditional raise (see TODO below).

    Raises:
        ValueError: if the key is not present in the cache (temporary behavior).
        RuntimeError: (fallback path) if the live query fails with HTTP 401/403,
            i.e. credentials are missing or invalid.
    """
    gql_db_fname = os.path.join(os.path.dirname(__file__), file_name)

    def get_mocked_queries() -> Any:
        # A missing cache file is treated as an empty cache.
        if not os.path.exists(gql_db_fname):
            return {}
        with gzip.open(gql_db_fname, encoding="utf-8", mode="rt") as f:
            return json.load(f)

    def save_mocked_queries(obj: Any) -> None:
        with gzip.open(gql_db_fname, encoding="utf-8", mode="wt") as f:
            json.dump(obj, f, indent=2)
            f.write("\n")

    key = key_function(*args)
    mocked_queries = get_mocked_queries()

    if key in mocked_queries:
        return mocked_queries[key]

    # TODO: Remove me once https://github.com/pytorch/pytorch/issues/160489 is resolved
    raise ValueError(f"Key {key} could not be found in gql_mocks")

    # NOTE: everything below is intentionally unreachable while the raise above
    # is in place; it is preserved so the live-fallback path can be restored by
    # simply deleting the raise.
    try:
        rc = fallback_function(*args)
    except HTTPError as err:
        if err.code == 401 or err.code == 403:
            err_msg = f"If you are seeing this message during workflow run, please make sure to update {file_name}"
            err_msg += f" locally, by deleting it and running {os.path.basename(__file__)} with"
            err_msg += " GitHub Personal Access Token passed via GITHUB_TOKEN"
            err_msg += " and drci api key passed via DRCI_BOT_KEY environment variables"
            if os.getenv("GITHUB_TOKEN") is None or os.getenv("DRCI_BOT_KEY") is None:
                err_msg = (
                    "Failed to update cached queries as GITHUB_TOKEN or DRCI_BOT_KEY "
                    + "is not defined. "
                    + err_msg
                )
            raise RuntimeError(err_msg) from err
        # Re-raise other HTTP errors; previously they were swallowed and the
        # code fell through to use `rc` while it was still unbound (NameError).
        raise
    mocked_queries[key] = rc

    save_mocked_queries(mocked_queries)

    return rc
 | |
| 
 | |
| 
 | |
def mocked_gh_graphql(query: str, **kwargs: Any) -> Any:
    """Cached stand-in for ``gh_graphql``, backed by the GQL mock file."""

    def key_function(query: str, kwargs: Any) -> str:
        # Cache key: sha256 of the query text followed by sorted kwargs.
        digest = sha256(query.encode("utf-8")).hexdigest()
        sorted_kwargs = " ".join(f"{k}={kwargs[k]}" for k in sorted(kwargs))
        return f"query_sha={digest} " + sorted_kwargs

    def gh_graphql_wrapper(query: str, kwargs: Any) -> Any:
        return gh_graphql(query, **kwargs)

    return mock_query(gh_graphql_wrapper, GQL_MOCKS, key_function, query, kwargs)
 | |
| 
 | |
| 
 | |
def mocked_drci_classifications(pr_num: int, project: str, num_retries: int = 3) -> Any:
    """Cached stand-in for ``get_drci_classifications``, backed by the Dr.CI mock file."""

    def key_function(x: Any, y: Any) -> str:
        return f"{x} {y}"

    return mock_query(
        get_drci_classifications,
        DRCI_MOCKS,
        key_function,
        pr_num,
        project,
    )
 | |
| 
 | |
| 
 | |
def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
    """Return an object mimicking trymerge's parsed command-line arguments."""

    class _Args:
        def __init__(self) -> None:
            self.revert = revert
            self.force = force
            self.pr_num = 76123
            self.dry_run = True
            # Non-zero on purpose, so code paths requiring an explicit
            # comment id are exercised.
            self.comment_id = 12345
            self.reason = "this is for testing"
            self.ignore_current = False
            self.check_mergeability = False

    return _Args()
 | |
| 
 | |
| 
 | |
def mock_remove_label(
    org: str, repo: str, pr_num: str, label: str, dry_run: bool
) -> None:
    """No-op replacement for ``gh_remove_label``; ignores all arguments."""
    return None
 | |
| 
 | |
| 
 | |
def mock_revert(
    repo: GitRepo,
    pr: GitHubPR,
    *,
    dry_run: bool = False,
    comment_id: Optional[int] = None,
    reason: Optional[str] = None,
) -> None:
    """No-op replacement for ``try_revert``; accepts and ignores all arguments."""
    return None
 | |
| 
 | |
| 
 | |
def mock_merge(
    pr: GitHubPR,
    repo: GitRepo,
    comment_id: int,
    dry_run: bool = False,
    skip_mandatory_checks: bool = False,
    timeout_minutes: int = 400,
    stale_pr_days: int = 3,
    ignore_current: bool = False,
) -> None:
    """No-op replacement for ``merge``; accepts and ignores all arguments."""
    return None
 | |
| 
 | |
| 
 | |
def mock_gh_get_info() -> Any:
    """Minimal PR-info payload: an open, same-repo PR with no changed files."""
    info: dict[str, Any] = {
        "closed": False,
        "isCrossRepository": False,
        "headRefName": "foo",
        "baseRefName": "bar",
        "baseRepository": {"defaultBranchRef": {"name": "bar"}},
        "files": {"nodes": [], "pageInfo": {"hasNextPage": False}},
        "changedFiles": 0,
    }
    return info
 | |
| 
 | |
| 
 | |
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> list[MergeRule]:
    """Single catch-all rule that requires a check which never exists."""
    rule = MergeRule(
        name="mock with nonexistent check",
        patterns=["*"],
        approved_by=[],
        mandatory_checks_name=["Lint", "Facebook CLA Check", "nonexistent"],
        ignore_flaky_failures=True,
    )
    return [rule]
 | |
| 
 | |
| 
 | |
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
    """Two rules: a catch-all one and an XLA-pin-specific one."""
    catch_all = MergeRule(
        name="super",
        patterns=["*"],
        approved_by=["pytorch/metamates", "ngimel"],
        mandatory_checks_name=[
            "Lint",
            "pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
        ],
        ignore_flaky_failures=True,
    )
    xla_pin = MergeRule(
        name="xla",
        patterns=[".github/ci_commit_pins/xla.txt"],
        approved_by=["pytorchbot"],
        mandatory_checks_name=[
            "Lint",
            "EasyCLA",
            "pull / linux-focal-py3_8-clang9-xla / build",
            "pull / linux-focal-py3_8-clang9-xla / test (xla, 1, 1, linux.12xlarge)",
        ],
        ignore_flaky_failures=True,
    )
    return [catch_all, xla_pin]
 | |
| 
 | |
| 
 | |
def mocked_read_merge_rules_approvers(
    repo: Any, org: str, project: str
) -> list[MergeRule]:
    """Two catch-all rules that differ only in their approver lists."""
    shared_checks = [
        "Lint",
        "pull",
    ]
    reviewers = MergeRule(
        name="Core Reviewers",
        patterns=["*"],
        approved_by=["1", "2", "3", "4", "5", "6"],
        mandatory_checks_name=list(shared_checks),
    )
    maintainers = MergeRule(
        name="Core Maintainers",
        patterns=["*"],
        approved_by=["1", "2", "malfet"],
        mandatory_checks_name=list(shared_checks),
    )
    return [reviewers, maintainers]
 | |
| 
 | |
| 
 | |
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> list[MergeRule]:
    """Always fail, to exercise the merge-rule error-handling path."""
    raise RuntimeError("testing")
 | |
| 
 | |
| 
 | |
def xla_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
    """Single XLA-pin rule with strict (non-flaky-tolerant) mandatory checks."""
    rule = MergeRule(
        name=" OSS CI / pytorchbot / XLA",
        patterns=[".github/ci_commit_pins/xla.txt"],
        approved_by=["pytorchbot"],
        mandatory_checks_name=[
            "Lint",
            "EasyCLA",
            "pull / linux-bionic-py3_8-clang8-xla / build",
            "pull / linux-bionic-py3_8-clang8-xla / test (xla, 1, 1, linux.4xlarge)",
            "inductor / cuda11.8-py3.10-gcc7-sm86 / test (inductor_torchbench_dynamic, 1, 1, linux.g5.4xlarge.nvidia.gpu)",
        ],
        ignore_flaky_failures=False,
    )
    return [rule]
 | |
| 
 | |
| 
 | |
class DummyGitRepo(GitRepo):
    """GitRepo pointed at the local checkout, with canned commit lookups."""

    def __init__(self) -> None:
        super().__init__(get_git_repo_dir(), get_git_remote_name())

    def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
        """Pretend every PR resolves to exactly one fake commit."""
        return ["FakeCommitSha"]

    def commit_message(self, ref: str) -> str:
        """Return a fixed commit message regardless of ``ref``."""
        return "super awesome commit message"
 | |
| 
 | |
| 
 | |
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestTryMerge(TestCase):
    """Tests for trymerge, driven by cached GraphQL/Dr.CI responses."""

    def test_merge_rules_valid(self, *args: Any) -> None:
        "Test that merge_rules.yaml can be parsed"
        repo = DummyGitRepo()
        merge_rules = read_merge_rules(repo, "pytorch", "pytorch")
        self.assertGreater(len(merge_rules), 1)

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
    def test_match_rules(self, *args: Any) -> None:
        "Tests that PR passes merge rules"
        pr = GitHubPR("pytorch", "pytorch", 109999)
        repo = DummyGitRepo()
        self.assertIsNotNone(find_matching_merge_rule(pr, repo))

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_raise)
    def test_read_merge_rules_fails(self, *args: Any) -> None:
        "Tests that PR fails to read the merge rules"
        pr = GitHubPR("pytorch", "pytorch", 77700)
        repo = DummyGitRepo()
        self.assertRaisesRegex(
            RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo)
        )

    @mock.patch(
        "trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_approvers
    )
    def test_match_rules_approvers(self, *args: Any) -> None:
        "Tests that PR has the necessary approvers"
        repo = DummyGitRepo()

        pr = GitHubPR("pytorch", "pytorch", 115329)
        # Test that all potential approvers across all rules are listed if the
        # PR doesn't have one of them
        for mock_rule in ["Core Reviewers", "Core Maintainers"]:
            self.assertRaisesRegex(
                RuntimeError,
                mock_rule,
                lambda: find_matching_merge_rule(pr, repo),
            )

        pr = GitHubPR("pytorch", "pytorch", 115495)
        # Test that PR with the correct approvers doesn't raise any exception
        self.assertIsNotNone(find_matching_merge_rule(pr, repo))

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
    def test_lint_fails(self, *args: Any) -> None:
        "Tests that PR fails mandatory lint check"
        pr = GitHubPR("pytorch", "pytorch", 90791)
        repo = DummyGitRepo()
        self.assertRaises(RuntimeError, lambda: find_matching_merge_rule(pr, repo))

    def test_get_last_comment(self, *args: Any) -> None:
        "Tests that last comment can be fetched"
        pr = GitHubPR("pytorch", "pytorch", 71759)
        comment = pr.get_last_comment()
        self.assertEqual(comment.author_login, "github-actions")
        self.assertIsNone(comment.editor_login)
        self.assertIn("You've committed this PR", comment.body_text)

    def test_get_author_null(self, *args: Any) -> None:
        """Tests that PR author can be computed
        If reply contains NULL
        """
        pr = GitHubPR("pytorch", "pytorch", 71759)
        author = pr.get_author()
        self.assertIsNotNone(author)
        self.assertIn("@", author)
        self.assertIsNone(pr.get_diff_revision())

        # PR with multiple contributors, but creator id is not among authors
        pr = GitHubPR("pytorch", "pytorch", 75095)
        self.assertEqual(pr.get_pr_creator_login(), "mruberry")
        author = pr.get_author()
        self.assertIsNotNone(author)

    def test_large_diff(self, *args: Any) -> None:
        "Tests that PR with 100+ files can be fetched"
        pr = GitHubPR("pytorch", "pytorch", 73099)
        self.assertGreater(pr.get_changed_files_count(), 100)
        flist = pr.get_changed_files()
        self.assertEqual(len(flist), pr.get_changed_files_count())

    def test_internal_changes(self, *args: Any) -> None:
        "Tests that PR with internal changes is detected"
        pr = GitHubPR("pytorch", "pytorch", 110140)
        self.assertTrue(pr.has_internal_changes())

    def test_comments_pagination(self, *args: Any) -> None:
        "Tests that PR with 50+ comments can be fetched"
        pr = GitHubPR("pytorch", "pytorch", 31093)
        self.assertGreater(len(pr.get_comments()), 50)

    def test_gql_complexity(self, *args: Any) -> None:
        "Fetch comments and conclusions for PR with 60 commits"
        # Previous version of GrapQL query used to cause HTTP/502 error
        # see https://gist.github.com/malfet/9b93bc7eeddeaf1d84546efc4f0c577f
        pr = GitHubPR("pytorch", "pytorch", 68111)
        self.assertGreater(len(pr.get_comments()), 20)
        # NS(09/27/2023): GitHub seems to recycle older checkruns
        # https://github.com/pytorch/pytorch/pull/68111/checks shows 0 runs
        # self.assertGreater(len(pr.get_checkrun_conclusions()), 3)
        self.assertGreater(pr.get_commit_count(), 60)

    @skip("GitHub doesn't keep this data anymore")
    def test_gql_retrieve_checksuites(self, *args: Any) -> None:
        "Fetch comments and conclusions for PR with 60 commits"
        pr = GitHubPR("pytorch", "pytorch", 94787)
        self.assertEqual(len(pr.get_checkrun_conclusions()), 182)

    def test_team_members(self, *args: Any) -> None:
        "Test fetching team members works"
        dev_infra_team = gh_get_team_members("pytorch", "pytorch-dev-infra")
        self.assertGreater(len(dev_infra_team), 2)
        with self.assertWarns(Warning):
            non_existing_team = gh_get_team_members("pytorch", "qwertyuiop")
            self.assertEqual(len(non_existing_team), 0)

    def test_get_author_many_commits(self, *args: Any) -> None:
        """Tests that authors for all commits can be fetched"""
        pr = GitHubPR("pytorch", "pytorch", 76118)
        authors = pr.get_authors()
        self.assertGreater(pr.get_commit_count(), 100)
        self.assertGreater(len(authors), 50)
        self.assertIn("@", pr.get_author())

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_NE)
    def test_pending_status_check(self, *args: Any) -> None:
        """Tests that PR with nonexistent/pending status checks fails with the right reason."""
        pr = GitHubPR("pytorch", "pytorch", 76118)
        repo = DummyGitRepo()
        self.assertRaisesRegex(
            MandatoryChecksMissingError,
            ".*are pending/not yet run.*",
            lambda: find_matching_merge_rule(pr, repo),
        )

    def test_get_author_many_reviews(self, *args: Any) -> None:
        """Tests that all reviews can be fetched"""
        pr = GitHubPR("pytorch", "pytorch", 76123)
        approved_by = pr.get_approved_by()
        self.assertGreater(len(approved_by), 0)
        assert pr._reviews is not None  # to pacify mypy
        self.assertGreater(len(pr._reviews), 100)

    def test_get_co_authors(self, *args: Any) -> None:
        """Tests that co-authors are recognized"""
        # NB: renamed from `get_co_authors` — without the `test_` prefix this
        # method was never collected (and hence never run) by unittest.
        pr = GitHubPR("pytorch", "pytorch", 118347)
        authors = pr.get_authors()
        self.assertIn("kit1980", authors)
        self.assertIn("Co-authored-by:", pr.gen_commit_message())

    def test_get_checkruns_many_runs(self, *args: Any) -> None:
        """Tests that all checkruns can be fetched"""
        pr = GitHubPR("pytorch", "pytorch", 105260)
        conclusions = pr.get_checkrun_conclusions()
        self.assertEqual(len(conclusions), 221)
        self.assertIn("pull / linux-docs / build-docs-cpp-false", conclusions)

    def test_cancelled_gets_ignored(self, *args: Any) -> None:
        """Tests that cancelled workflow does not override existing successful status"""
        pr = GitHubPR("pytorch", "pytorch", 110367)
        conclusions = pr.get_checkrun_conclusions()
        lint_checks = [name for name in conclusions.keys() if "Lint" in name]
        self.assertGreater(len(lint_checks), 0)
        self.assertTrue(
            all(conclusions[name].status == "SUCCESS" for name in lint_checks)
        )

    def test_get_review_comment_by_id(self, *args: Any) -> None:
        """Tests that even if the comment requested was actually a review instead of a simple comment, we can still find it"""
        pr = GitHubPR("pytorch", "pytorch", 107070)
        review_comment_id = 1582767635
        comment = pr.get_comment_by_id(review_comment_id)
        self.assertIsNotNone(comment)

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(True, False))
    @mock.patch("trymerge.try_revert", side_effect=mock_revert)
    def test_main_revert(self, mock_revert: Any, *args: Any) -> None:
        """Tests that the revert CLI flag routes to try_revert"""
        trymerge_main()
        mock_revert.assert_called_once()

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, True))
    @mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
    @mock.patch("trymerge.merge", side_effect=mock_merge)
    def test_main_force(
        self, mock_merge: Any, mock_parse_args: Any, *args: Any
    ) -> None:
        """Tests that the force CLI flag skips mandatory checks"""
        trymerge_main()
        mock_merge.assert_called_once_with(
            mock.ANY,
            mock.ANY,
            comment_id=mock.ANY,
            dry_run=mock.ANY,
            skip_mandatory_checks=True,
            ignore_current=False,
        )

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, False))
    @mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
    @mock.patch("trymerge.merge", side_effect=mock_merge)
    def test_main_merge(self, mock_merge: Any, *args: Any) -> None:
        """Tests that the default merge path keeps mandatory checks"""
        trymerge_main()
        mock_merge.assert_called_once_with(
            mock.ANY,
            mock.ANY,
            comment_id=mock.ANY,
            dry_run=mock.ANY,
            skip_mandatory_checks=False,
            ignore_current=False,
        )

    @mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
    def test_revert_rules(self, *args: Any) -> None:
        """Tests that reverts from collaborators are allowed"""
        pr = GitHubPR("pytorch", "pytorch", 79694)
        repo = DummyGitRepo()
        self.assertIsNotNone(validate_revert(repo, pr, comment_id=1189459845))

    def test_get_changed_files(self, *args: Any) -> None:
        """
        Tests that the list changed files in a PR doesn't include duplicates
        """
        pr = GitHubPR("pytorch", "pytorch", 95233)
        try:
            changed_files = pr.get_changed_files()
        except RuntimeError as error:
            self.fail(f"get_changed_files throws an exception: {error}")

        self.assertEqual(len(changed_files), pr.get_changed_files_count())

    def test_revert_codev_abandoned_diff_succeeds(self, *args: Any) -> None:
        """Tests that a codev PR can be reverted when its commit message is the PR body"""
        pr = GitHubPR("pytorch", "pytorch", 100652)

        class GitRepoCoDev(DummyGitRepo):
            def commit_message(self, ref: str) -> str:
                return pr.get_body()

        repo = GitRepoCoDev()
        validate_revert(repo, pr, comment_id=1588195237)

    def test_pr_changed_submodule_detection(self, *args: Any) -> None:
        """Tests detection of submodule updates in PRs"""
        # Updates submodule during dev-cycle but reverts it later
        pr = GitHubPR("pytorch", "pytorch", 95045)
        self.assertEqual(pr.get_changed_submodules(), [])
        self.assertFalse(pr.has_invalid_submodule_updates())

        # PR updates ideep
        pr = GitHubPR("pytorch", "pytorch", 94939)
        self.assertEqual(pr.get_changed_submodules(), ["third_party/ideep"])
        self.assertTrue(pr.has_invalid_submodule_updates())

        # Automated submodule update
        pr = GitHubPR("pytorch", "pytorch", 91051)
        self.assertEqual(pr.get_changed_submodules(), ["third_party/kineto"])
        self.assertFalse(pr.has_invalid_submodule_updates())

    def test_remove_job_name_suffix(self, *args: Any) -> None:
        """Tests that shard/runner suffixes are stripped from job names"""
        test_cases = [
            {
                "name": "linux-bionic-cuda12.6-py3.10-gcc9-sm86 / test (default, 1, 5, linux.g5.4xlarge.nvidia.gpu)",
                "expected": "linux-bionic-cuda12.6-py3.10-gcc9-sm86 / test (default)",
            },
            {
                "name": "android-emulator-build-test / build-and-test (default, 1, 1, ubuntu-20.04-16x)",
                "expected": "android-emulator-build-test / build-and-test (default)",
            },
            {
                "name": "linux-focal-rocm5.4.2-py3.8 / build",
                "expected": "linux-focal-rocm5.4.2-py3.8 / build",
            },
            {
                "name": "libtorch-cpu-shared-with-deps-release-build",
                "expected": "libtorch-cpu-shared-with-deps-release-build",
            },
            {
                "name": "manywheel-py3_8-cuda11_8-test / test",
                "expected": "manywheel-py3_8-cuda11_8-test / test",
            },
            {
                "name": "lintrunner / linux-job",
                "expected": "lintrunner / linux-job",
            },
            {
                "name": "Test `run_test.py` is usable without boto3",
                "expected": "Test `run_test.py` is usable without boto3",
            },
        ]

        for case in test_cases:
            self.assertEqual(case["expected"], remove_job_name_suffix(case["name"]))

    def test_get_merge_base(self, *args: Any) -> None:
        """Tests that the merge base is fetched once and then cached"""
        pr = GitHubPR("pytorch", "pytorch", 104121)

        mock_merge_base = "mocked-sha"
        with mock.patch(
            "trymerge.gh_fetch_merge_base", return_value=mock_merge_base
        ) as mocked_gh_fetch_merge_base:
            self.assertEqual(mock_merge_base, pr.get_merge_base())

            # Make sure that consecutive calls will use the same merge base instead of
            # making another query
            self.assertEqual(mock_merge_base, pr.get_merge_base())
            mocked_gh_fetch_merge_base.assert_called_once()

    def test_app_can_revert(self, *args: Any) -> None:
        """Tests that the auto-revert app may revert but an impostor user may not"""
        pr = GitHubPR("pytorch", "pytorch", 164660)
        repo = DummyGitRepo()
        app_comment_id, impostor_comment_id = 3375785595, 3377647892
        # Check that app can revert
        self.assertIsNotNone(validate_revert(repo, pr, comment_id=app_comment_id))
        # But impostor can not
        self.assertRaises(
            PostCommentError,
            lambda: validate_revert(repo, pr, comment_id=impostor_comment_id),
        )
        # Despite its name being the name of the bot
        self.assertEqual(
            pr.get_comment_by_id(impostor_comment_id).author_login,
            "pytorch-auto-revert",
        )
 | |
| 
 | |
| 
 | |
| @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
 | |
| @mock.patch("trymerge.gh_fetch_merge_base", return_value="")
 | |
| @mock.patch(
 | |
|     "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
 | |
| )
 | |
| class TestBypassFailures(TestCase):
 | |
|     def test_get_classifications(self, *args: Any) -> None:
 | |
|         pr = GitHubPR("pytorch", "pytorch", 109584)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [],
 | |
|         )
 | |
|         self.assertTrue(
 | |
|             checks[
 | |
|                 "pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)"
 | |
|             ].classification
 | |
|             == "BROKEN_TRUNK"
 | |
|         )
 | |
|         self.assertTrue(
 | |
|             checks[
 | |
|                 "trunk / win-vs2019-cpu-py3 / test (default, 2, 3, windows.4xlarge.nonephemeral)"
 | |
|             ].classification
 | |
|             == "FLAKY"
 | |
|         )
 | |
|         self.assertTrue(
 | |
|             checks[
 | |
|                 "pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge)"
 | |
|             ].classification
 | |
|             == "FLAKY"
 | |
|         )
 | |
|         self.assertTrue(
 | |
|             checks[
 | |
|                 "pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)"
 | |
|             ].classification
 | |
|             == "FLAKY"
 | |
|         )
 | |
| 
 | |
|         # Set the threshold larger or equal to the number of ok failures
 | |
|         pending, failed, ignorable = categorize_checks(
 | |
|             checks, list(checks.keys()), ok_failed_checks_threshold=6
 | |
|         )
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 4)
 | |
|         self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
 | |
| 
 | |
|         # Not set any threshold, defaults to -1 to ignore all flaky and broken trunk failures
 | |
|         pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 4)
 | |
|         self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
 | |
| 
 | |
|         # Set the threshold lower than the number of ok failures
 | |
|         pending, failed, ignorable = categorize_checks(
 | |
|             checks, list(checks.keys()), ok_failed_checks_threshold=1
 | |
|         )
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 6)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 4)
 | |
|         self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
 | |
| 
 | |
|         # Set the threshold to 0 like when ignore_flaky_failures is on
 | |
|         pending, failed, ignorable = categorize_checks(
 | |
|             checks, list(checks.keys()), ok_failed_checks_threshold=1
 | |
|         )
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 6)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 4)
 | |
|         self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
 | |
| 
 | |
|     def test_get_classifications_flaky_fullname(self, *args: Any) -> None:
 | |
|         pr = GitHubPR("pytorch", "pytorch", 110362)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [],
 | |
|         )
 | |
|         pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 1)
 | |
| 
 | |
|     def test_get_classifications_invalid_cancel(self, *args: Any) -> None:
 | |
|         pr = GitHubPR("pytorch", "pytorch", 110367)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [],
 | |
|         )
 | |
|         pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 0)
 | |
|         self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 0)
 | |
|         self.assertTrue(len(ignorable["UNSTABLE"]) == 3)
 | |
| 
 | |
|     def test_get_classifications_similar_failures(self, *args: Any) -> None:
 | |
|         pr = GitHubPR("pytorch", "pytorch", 109750)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [],
 | |
|         )
 | |
|         pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 1)
 | |
| 
 | |
|     def test_get_classifications_unstable(self, *args: Any) -> None:
 | |
|         pr = GitHubPR("pytorch", "pytorch", 104312)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [],
 | |
|         )
 | |
|         workflow_name = "linux-bionic-cuda12.1-py3.10-gcc9-bazel-test"
 | |
|         job_name = "build-and-test (default, 1, 1, linux.4xlarge.nvidia.gpu, unstable)"
 | |
|         self.assertTrue(
 | |
|             checks[f"pull / {workflow_name} / {job_name}"].classification == "UNSTABLE"
 | |
|         )
 | |
|         pending, failed, ignorable = categorize_checks(
 | |
|             checks, list(checks.keys()), ok_failed_checks_threshold=1
 | |
|         )
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["UNSTABLE"]) == 1)
 | |
| 
 | |
|         # Add another test case where there is no unstable keyword in the job name, but
 | |
|         # the job has already been marked as unstable
 | |
|         pr = GitHubPR("pytorch", "executorch", 3318)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [],
 | |
|         )
 | |
|         print(checks)
 | |
|         workflow_name = "test-llama-app"
 | |
|         job_name = "mobile-job (android)"
 | |
|         self.assertTrue(
 | |
|             checks[f"Android / {workflow_name} / {job_name}"].classification
 | |
|             == "UNSTABLE"
 | |
|         )
 | |
|         pending, failed, ignorable = categorize_checks(
 | |
|             checks, list(checks.keys()), ok_failed_checks_threshold=1
 | |
|         )
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["UNSTABLE"]) == 1)
 | |
| 
 | |
|     def test_get_classifications_broken_trunk(self, *args: Any) -> None:
 | |
|         # The mock merge base is the actual value returned by gh_fetch_merge_base
 | |
|         test_cases = [
 | |
|             {
 | |
|                 # This PR had one broken trunk failure but it was run on a different shard
 | |
|                 # than the one on the base commit. This should still count as broken trunk
 | |
|                 "pr_num": 104214,
 | |
|                 "related_failure_count": 0,
 | |
|                 "flaky_or_broken_trunk": 1,
 | |
|             },
 | |
|             {
 | |
|                 # This PR had one broken trunk failure and it used ghstack
 | |
|                 "pr_num": 105145,
 | |
|                 "related_failure_count": 0,
 | |
|                 "flaky_or_broken_trunk": 1,
 | |
|             },
 | |
|             {
 | |
|                 # The failure on the merge base was retried successfully and
 | |
|                 # its conclusion changed from failure to success. We want to
 | |
|                 # keep the failure record from the merge base so that it can
 | |
|                 # be used to detect broken trunk
 | |
|                 "pr_num": 107160,
 | |
|                 "related_failure_count": 0,
 | |
|                 "flaky_or_broken_trunk": 1,
 | |
|             },
 | |
|             {
 | |
|                 # This PR used Dr.CI broken trunk classification
 | |
|                 "pr_num": 111253,
 | |
|                 "related_failure_count": 1,
 | |
|                 "flaky_or_broken_trunk": 1,
 | |
|             },
 | |
|         ]
 | |
| 
 | |
|         for case in test_cases:
 | |
|             pr_num = case["pr_num"]
 | |
|             related_failure_count = case["related_failure_count"]
 | |
|             flaky_or_broken_trunk = case["flaky_or_broken_trunk"]
 | |
| 
 | |
|             pr = GitHubPR("pytorch", "pytorch", pr_num)
 | |
|             checks = pr.get_checkrun_conclusions()
 | |
|             checks = get_classifications(
 | |
|                 pr.pr_num,
 | |
|                 pr.project,
 | |
|                 checks,
 | |
|                 [],
 | |
|             )
 | |
| 
 | |
|             pending, failed, _ = categorize_checks(checks, list(checks.keys()))
 | |
|             self.assertTrue(len(pending) == 0)
 | |
|             self.assertTrue(len(failed) == related_failure_count)
 | |
| 
 | |
|             # When the ok_failed_checks_threshold is set to 0, the broken trunk failure
 | |
|             # won't be ignored
 | |
|             pending, failed, _ = categorize_checks(
 | |
|                 checks, list(checks.keys()), ok_failed_checks_threshold=0
 | |
|             )
 | |
|             self.assertTrue(len(pending) == 0)
 | |
|             self.assertTrue(
 | |
|                 len(failed) == flaky_or_broken_trunk + related_failure_count
 | |
|             )
 | |
| 
 | |
|     def test_ignore_current(self, *args: Any) -> None:
 | |
|         # Test various interactions of the failure classifier to ensure that ignore
 | |
|         # current checks takes place after other classifications: flaky, unstable,
 | |
|         # or broken trunk. Only actual new failures should be kept in the list of
 | |
|         # ignore current checks to use to record force merge with actual failures
 | |
|         flaky = "pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)"
 | |
|         broken_trunk = (
 | |
|             "pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)"
 | |
|         )
 | |
| 
 | |
|         pr = GitHubPR("pytorch", "pytorch", 109584)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
| 
 | |
|         # Known flaky failure takes precedence over ignore current (need to set the
 | |
|         # merge base here to get the results from Dr. CI, and that categorize the
 | |
|         # broken trunk failure too
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [broken_trunk, flaky],
 | |
|         )
 | |
|         self.assertTrue(checks[flaky].classification == "FLAKY")
 | |
|         self.assertTrue(checks[broken_trunk].classification == "BROKEN_TRUNK")
 | |
|         _, failed, ignorable = categorize_checks(checks, list(checks.keys()))
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["IGNORE_CURRENT_CHECK"]) == 0)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 4)
 | |
|         self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
 | |
| 
 | |
|     def test_get_classifications_wrong_workflow_name(self, *args: Any) -> None:
 | |
|         pr = GitHubPR("pytorch", "pytorch", 123104)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
| 
 | |
|         check_name = "linux-binary-conda / conda-py3_8-cuda11_8-build / build"
 | |
|         check_name_workflow_path = ".github/workflows/generated-linux-binary-conda-nightly.yml / conda-py3_8-cuda11_8-build / build"
 | |
| 
 | |
|         # Mock a check where the workflow name uses the full path
 | |
|         checks[check_name_workflow_path] = JobCheckState(
 | |
|             check_name_workflow_path,
 | |
|             checks[check_name].url,
 | |
|             checks[check_name].status,
 | |
|             checks[check_name].classification,
 | |
|             checks[check_name].job_id,
 | |
|             checks[check_name].title,
 | |
|             checks[check_name].summary,
 | |
|         )
 | |
|         del checks[check_name]
 | |
| 
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [],
 | |
|         )
 | |
|         pending, failed, ignorable = categorize_checks(
 | |
|             checks,
 | |
|             list(checks.keys()),
 | |
|         )
 | |
| 
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 1)
 | |
|         self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 0)
 | |
| 
 | |
|     def test_ignore_failures_older_run_same_workflow(self, *args: Any) -> None:
 | |
|         pr = GitHubPR("pytorch", "pytorch", 129013)
 | |
|         checks = pr.get_checkrun_conclusions()
 | |
|         checks = get_classifications(
 | |
|             pr.pr_num,
 | |
|             pr.project,
 | |
|             checks,
 | |
|             [],
 | |
|         )
 | |
|         pending, failed, ignorable = categorize_checks(
 | |
|             checks,
 | |
|             list(checks.keys()),
 | |
|         )
 | |
|         self.assertTrue(len(pending) == 0)
 | |
|         self.assertTrue(len(failed) == 0)
 | |
|         self.assertTrue(len(ignorable["FLAKY"]) == 2)
 | |
|         self.assertTrue(len(ignorable["UNSTABLE"]) == 13)
 | |
| 
 | |
|     @mock.patch("trymerge.read_merge_rules", side_effect=xla_merge_rules)
 | |
|     def test_dont_ignore_flaky_failures(self, *args: Any) -> None:
 | |
|         """
 | |
|         Regression test for https://github.com/pytorch/test-infra/issues/4126
 | |
|         """
 | |
|         pr = GitHubPR("pytorch", "pytorch", 105312)
 | |
|         repo = DummyGitRepo()
 | |
|         # Check that failure is classified as flaky but still raises exception
 | |
|         with warnings.catch_warnings(record=True) as w, self.assertRaises(RuntimeError):
 | |
|             find_matching_merge_rule(pr, repo)
 | |
|         self.assertEqual(len(w), 1)
 | |
|         self.assertIn(
 | |
|             "1 checks failed but were likely due flakiness or broken trunk",
 | |
|             str(w[0].message),
 | |
|         )
 | |
| 
 | |
| 
 | |
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch("trymerge.get_drci_classifications", return_value={})
class TestBypassFailuresOnSandCastle(TestCase):
    """Failure-bypass behavior when Dr.CI classifications are unavailable:
    with an empty/missing Dr.CI payload, classification must fall back to the
    Dr.CI check run's own summary (and fail closed when that is absent)."""

    def test_get_classifications(self, *args: Any) -> None:
        """With classifications taken from the Dr.CI check run summary, the
        two failures on mocked PR #111467 are one FLAKY and one BROKEN_TRUNK."""
        pr = GitHubPR("pytorch", "pytorch", 111467)
        checks = pr.get_checkrun_conclusions()
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        # assertEqual gives actionable diagnostics vs assertTrue(len(...) == N)
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 0)
        self.assertEqual(len(ignorable["FLAKY"]), 1)
        self.assertEqual(len(ignorable["BROKEN_TRUNK"]), 1)

    def test_get_classifications_drci_checkrun_not_found(self, *args: Any) -> None:
        """When the Dr.CI check run is missing or has no usable summary, no
        failure can be classified as ignorable — both failures stay failed."""
        pr = GitHubPR("pytorch", "pytorch", 111467)

        # No summary
        checks = pr.get_checkrun_conclusions()
        checks[DRCI_CHECKRUN_NAME] = JobCheckState(
            DRCI_CHECKRUN_NAME,
            "",
            "NEUTRAL",
            None,
            1,
            "",
            None,
        )
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 2)

        # Empty summary
        checks = pr.get_checkrun_conclusions()
        checks[DRCI_CHECKRUN_NAME] = JobCheckState(
            DRCI_CHECKRUN_NAME,
            "",
            "NEUTRAL",
            None,
            1,
            "",
            "",
        )
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 2)

        # No Dr.CI checkrun
        checks = pr.get_checkrun_conclusions()
        del checks[DRCI_CHECKRUN_NAME]
        checks = get_classifications(
            pr.pr_num,
            pr.project,
            checks,
            [],
        )
        pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
        self.assertEqual(len(pending), 0)
        self.assertEqual(len(failed), 2)
 | |
| 
 | |
| 
 | |
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestGitHubPRGhstackDependencies(TestCase):
    """Commit-message generation and stack merging for ghstack PRs."""

    def test_pr_dependencies(self, *args: Any) -> None:
        """Without ghstack_deps, the generated message carries no
        'ghstack dependencies' trailer."""
        pr = GitHubPR("pytorch", "pytorch", 106068)
        msg = pr.gen_commit_message(filter_ghstack=True)
        self.assertEqual(
            msg,
            f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
            "Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: https://github.com/ezyang, https://github.com/fegin\n",
        )

    def test_pr_dependencies_ghstack(self, *args: Any) -> None:
        """With ghstack_deps, the message ends with a 'ghstack dependencies'
        trailer listing the dependent PR numbers in order."""
        pr0 = GitHubPR("pytorch", "pytorch", 106032)
        pr1 = GitHubPR("pytorch", "pytorch", 106033)
        pr2 = GitHubPR("pytorch", "pytorch", 106034)
        pr = GitHubPR("pytorch", "pytorch", 106068)
        msg = pr.gen_commit_message(filter_ghstack=True, ghstack_deps=[pr0, pr1, pr2])
        self.assertEqual(
            msg,
            f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
            "Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: https://github.com/ezyang, https://github.com/fegin\n"
            "ghstack dependencies: #106032, #106033, #106034\n",
        )

    @skip(
        reason="This test is run against a mutable PR that has changed, so it no longer works. The test should be changed"
    )
    @mock.patch("trymerge.read_merge_rules")
    @mock.patch("trymerge.GitRepo")
    @mock.patch("trymerge.get_ghstack_prs")
    def test_merge_ghstack_into(
        self,
        mock_get_ghstack_prs: mock.MagicMock,
        mock_repo: mock.MagicMock,
        mock_merge_rules: mock.MagicMock,
        *args: Any,
    ) -> None:
        """
        Test that the merge_ghstack_into method works correctly
        """
        pr0 = GitHubPR("pytorch", "pytorch", 106032)
        pr1 = GitHubPR("pytorch", "pytorch", 106033)
        pr2 = GitHubPR("pytorch", "pytorch", 106034)
        pr = GitHubPR("pytorch", "pytorch", 106068)

        # note: in reverse order (e.g. self.pr is the last commit, top of the stack)
        mock_get_ghstack_prs.return_value = [
            (pr0, "rev0"),
            (pr1, "rev1"),
            (pr2, "rev2"),
            (pr, "rev123"),
        ]

        mock_merge_rules.return_value = [
            MergeRule(
                "Mock title", patterns=["*"], approved_by=[], mandatory_checks_name=None
            )
        ]

        mock_repo.cherry_pick.return_value = None
        mock_repo.amend_commit_message.return_value = None

        # Call the method under test
        res = pr.merge_ghstack_into(mock_repo, True)

        self.assertEqual(res, [pr2, pr])

        mock_repo.cherry_pick.assert_any_call("rev2")
        mock_repo.cherry_pick.assert_any_call("rev123")

        # assertNotIn prints the full call list on failure, unlike
        # assertTrue(call not in call_args_list)
        self.assertNotIn(mock.call("rev1"), mock_repo.cherry_pick.call_args_list)

        # Verify the first call
        message = mock_repo.amend_commit_message.call_args_list[0].args[0]
        prefix = (
            "[FSDP] Optimize away intermediate `div_` for HSDP (#106034)\n\n\r\n"
            "### Background: Gradient Pre-Divide"
        )
        suffix = (
            "\nPull Request resolved: https://github.com/pytorch/pytorch/pull/106034\nApproved by: \nghstack "
            "dependencies: #106032, #106033\n"
        )

        self.assertTrue(message.startswith(prefix))
        self.assertTrue(message.endswith(suffix))

        # Verify the second call
        mock_repo.amend_commit_message.assert_any_call(
            "[FSDP] Break up `_post_backward_hook` into smaller funcs (#106068)\n\n\n"
            "Differential Revision: ["
            "D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)\n"
            "Pull Request resolved: "
            "https://github.com/pytorch/pytorch/pull/106068\n"
            "Approved by: \n"
            "ghstack dependencies: #106032, #106033, #106034\n"
        )
 | |
| 
 | |
| 
 | |
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
@mock.patch.object(DummyGitRepo, "commit_message")
class TestRevListToPR(TestCase):
    """Tests for the _revlist_to_prs function."""

    def test__revlist_to_prs_zero_matches(
        self, mock_commit_message: mock.MagicMock, *args: Any
    ) -> None:
        """A commit message mentioning zero PRs must raise an error."""
        pull = GitHubPR("pytorch", "pytorch", 154098)
        git_repo = DummyGitRepo()
        mock_commit_message.return_value = "no PRs"
        with self.assertRaisesRegex(RuntimeError, "PRs mentioned in commit dummy: 0."):
            _revlist_to_prs(git_repo, pull, ["dummy"])

    def test__revlist_to_prs_two_prs(
        self, mock_commit_message: mock.MagicMock, *args: Any
    ) -> None:
        """A commit message mentioning two PRs must raise an error."""
        pull = GitHubPR("pytorch", "pytorch", 154394)
        git_repo = DummyGitRepo()
        # https://github.com/pytorch/pytorch/commit/343c56e7650f55fd030aca0b9275d6d73501d3f4
        mock_commit_message.return_value = """add sticky cache pgo

ghstack-source-id: 9bc6dee0b427819f978bfabccb72727ba8be2f81
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/154098

ghstack-source-id: 9bc6dee0b427819f978bfabccb72727ba8be2f81
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154394"""
        with self.assertRaisesRegex(RuntimeError, "PRs mentioned in commit dummy: 2."):
            _revlist_to_prs(git_repo, pull, ["dummy"])
 | |
| 
 | |
| 
 | |
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestTimelineFunctions(TestCase):
    """Tests for the timeline-related helpers: extracting commit SHAs from
    issue-timeline events and resolving the PR head SHA at a given comment."""

    def test_sha_from_committed_event(self, *args: Any) -> None:
        """Test extracting SHA from a committed timeline event."""
        # Based on actual GitHub API format - committed events have "sha" at top level
        event = {
            "event": "committed",
            "sha": "fb21ce932ded6670c918804a0d9151b773770a7c",
        }
        self.assertEqual(
            sha_from_committed_event(event), "fb21ce932ded6670c918804a0d9151b773770a7c"
        )

        # Missing SHA yields None rather than raising
        event_no_sha = {"event": "committed"}
        self.assertIsNone(sha_from_committed_event(event_no_sha))

    def test_sha_from_force_push_after(self, *args: Any) -> None:
        """Test extracting SHA from a head_ref_force_pushed timeline event."""
        # Force-push events appear in two shapes: a legacy payload with the SHA
        # nested under {"after": {"sha": ...}} and the current GitHub API format
        # with a top-level "commit_id". The assertions below require both to be
        # handled.

        # Legacy format: SHA nested under "after"
        event_legacy = {
            "event": "head_ref_force_pushed",
            "after": {"sha": "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e"},
        }
        self.assertEqual(
            sha_from_force_push_after(event_legacy),
            "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        )

        # Current GitHub API format: SHA in the top-level "commit_id" field
        event_real_api = {
            "event": "head_ref_force_pushed",
            "commit_id": "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        }
        self.assertEqual(
            sha_from_force_push_after(event_real_api),
            "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        )

        # Missing SHA yields None rather than raising
        event_no_sha = {"event": "head_ref_force_pushed"}
        self.assertIsNone(sha_from_force_push_after(event_no_sha))

    @mock.patch("trymerge.gh_fetch_json_list")
    def test_iter_issue_timeline_until_comment(
        self, mock_gh_fetch_json_list: Any, *args: Any
    ) -> None:
        """Test that timeline iteration stops at the target comment."""
        # Mock timeline data based on actual GitHub API format
        timeline_data = [
            {"event": "commented", "id": 100, "body": "first comment"},
            {"event": "committed", "sha": "fb21ce932ded6670c918804a0d9151b773770a7c"},
            {"event": "commented", "id": 200, "body": "target comment"},
            {"event": "commented", "id": 300, "body": "after target"},
        ]
        mock_gh_fetch_json_list.return_value = timeline_data

        # Iteration must include the target comment (id 200) and exclude
        # everything after it
        events = list(iter_issue_timeline_until_comment("pytorch", "pytorch", 123, 200))
        self.assertEqual(len(events), 3)  # Should stop at target comment
        self.assertEqual(events[0]["event"], "commented")
        self.assertEqual(events[0]["id"], 100)
        self.assertEqual(events[1]["event"], "committed")
        self.assertEqual(events[1]["sha"], "fb21ce932ded6670c918804a0d9151b773770a7c")
        self.assertEqual(events[2]["event"], "commented")
        self.assertEqual(events[2]["id"], 200)

    @mock.patch("trymerge.gh_fetch_json_list")
    def test_iter_issue_timeline_until_comment_not_found(
        self, mock_gh_fetch_json_list: Any, *args: Any
    ) -> None:
        """Test timeline iteration when the target comment is not found."""
        # Mock empty timeline: iteration simply produces no events
        mock_gh_fetch_json_list.return_value = []

        events = list(iter_issue_timeline_until_comment("pytorch", "pytorch", 123, 999))
        self.assertEqual(len(events), 0)

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_commit_after_comment(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """The SHA returned is from the last commit BEFORE the comment; a
        force push after the comment must not affect the result."""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "commented", "id": 100},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit2")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_force_push_before_comment(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """A force push (current "commit_id" format) before the comment
        supersedes earlier commits."""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "head_ref_force_pushed", "commit_id": "commit3"},
            {"event": "commented", "id": 100},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_force_push_before_comment_legacy_mode(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Same as above, but with the legacy nested {"after": {"sha": ...}}
        force-push payload."""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
            {"event": "commented", "id": 100},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_multiple_comments(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Each comment id resolves to the head SHA at that point in time."""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "commented", "id": 100},
            {"event": "committed", "sha": "commit2"},
            {"event": "commented", "id": 200},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
            {"event": "commented", "id": 300},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(200)
        self.assertEqual(sha, "commit2")
        sha = pr.get_commit_sha_at_comment(300)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_no_events(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """No commit/force-push event before the comment yields None."""
        mock_iter_timeline.return_value = [
            {"event": "commented", "id": 100},
            {"event": "labeled", "label": {"name": "test"}},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertIsNone(sha)

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_exception(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Errors while fetching the timeline result in None, not a raise."""
        mock_iter_timeline.side_effect = Exception("API error")
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertIsNone(sha)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # unittest.main (imported at the top of this file) discovers and runs all
    # TestCase classes defined in this module
    main()
 |