diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py index aa64fe15387e..f97c2f6c4403 100644 --- a/.github/scripts/gitutils.py +++ b/.github/scripts/gitutils.py @@ -273,6 +273,11 @@ class GitRepo: def amend_commit_message(self, msg: str) -> None: self._run_git("commit", "--amend", "-m", msg) + def diff(self, from_ref: str, to_ref: Optional[str] = None) -> str: + if to_ref is None: + return self._run_git("diff", f"{from_ref}^!") + return self._run_git("diff", f"{from_ref}..{to_ref}") + def clone_repo(username: str, password: str, org: str, project: str) -> GitRepo: path = tempfile.mkdtemp() @@ -331,3 +336,18 @@ def patterns_to_regex(allowed_patterns: List[str]) -> Any: rc += c rc += ")" return re.compile(rc) + +def _shasum(value: str) -> str: + import hashlib + m = hashlib.sha256() + m.update(value.encode("utf-8")) + return m.hexdigest() + + +def are_ghstack_branches_in_sync(repo: GitRepo, head_ref: str) -> bool: + """ Checks that diff between base and head is the same as diff between orig and its parent """ + orig_ref = re.sub(r'/head$', '/orig', head_ref) + base_ref = re.sub(r'/head$', '/base', head_ref) + orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}")) + head_diff_sha = _shasum(repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}")) + return orig_diff_sha == head_diff_sha diff --git a/.github/scripts/test_gitutils.py b/.github/scripts/test_gitutils.py index 78696771d993..9987cdea9781 100644 --- a/.github/scripts/test_gitutils.py +++ b/.github/scripts/test_gitutils.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 -from gitutils import PeekableIterator, patterns_to_regex -from unittest import TestCase, main +from gitutils import PeekableIterator, patterns_to_regex, GitRepo, are_ghstack_branches_in_sync, _shasum +from unittest import TestCase, main, SkipTest +from pathlib import Path + + +BASE_DIR = Path(__file__).parent + class TestPeekableIterator(TestCase): def test_iterator(self, input_: str = "abcdef") -> None: @@ -35,5 +40,34 @@ class TestPattern(TestCase): self.assertTrue(patterns_re.match(filename)) +class TestGitRepo(TestCase): + def setUp(self) -> None: + repo_dir = BASE_DIR.parent.parent.absolute() + if not (repo_dir / ".git").is_dir(): + raise SkipTest("Can't find git directory, make sure to run this test on real repo checkout") + self.repo = GitRepo(str(repo_dir)) + + def _skip_if_ref_does_not_exist(self, ref: str) -> None: + """ Skip test if ref is missing as stale branches are deleted with time """ + try: + self.repo.show_ref(ref) + except RuntimeError as e: + raise SkipTest(f"Can't find head ref {ref} due to {str(e)}") from e + + def test_compute_diff(self) -> None: + diff = self.repo.diff("HEAD") + sha = _shasum(diff) + self.assertEqual(len(sha), 64) + + def test_ghstack_branches_in_sync(self) -> None: + head_ref = "gh/SS-JIA/206/head" + self._skip_if_ref_does_not_exist(head_ref) + self.assertTrue(are_ghstack_branches_in_sync(self.repo, head_ref)) + + def test_ghstack_branches_not_in_sync(self) -> None: + head_ref = "gh/clee2000/1/head" + self._skip_if_ref_does_not_exist(head_ref) + self.assertFalse(are_ghstack_branches_in_sync(self.repo, head_ref)) + if __name__ == '__main__': main() diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 182d39e0f5de..f8a59d905c76 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -29,6 +29,7 @@ from pathlib import Path from gitutils import ( GitRepo, + are_ghstack_branches_in_sync, get_git_remote_name, get_git_repo_dir, patterns_to_regex, @@ -619,6 +620,7 @@ def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) - return False return comment.author_login == "facebook-github-bot" + def get_ghstack_prs(repo: GitRepo, pr: "GitHubPR") -> List[Tuple["GitHubPR", str]]: ''' Get the open PRs in the stack that are below this PR. Throws error if any of the PRs are out of sync. @@ -646,9 +648,7 @@ def get_ghstack_prs(repo: GitRepo, pr: "GitHubPR") -> List[Tuple["GitHubPR", str entire_stack.append((pr, rev)) for stacked_pr, rev in entire_stack: - commit_sha = stacked_pr.last_commit()['oid'] - tree_sha = repo._run_git("rev-parse", commit_sha + "^{tree}") - if tree_sha not in repo.commit_message(rev): + if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref()): raise RuntimeError( f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on " + f"branch {orig_ref} that would be merged into master. " +