mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[GHF] Fix ghstack branches in sync logic (#93298)
Test plan: ```python from git_utils import are_ghstack_branches_in_sync,GitRepo repo=GitRepo("/Users/nshulga/git/pytorch/pytorch") are_ghstack_branches_in_sync(repo, "gh/SS-JIA/206/head") ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/93298 Approved by: https://github.com/clee2000, https://github.com/ZainRizvi
This commit is contained in:
committed by
PyTorch MergeBot
parent
54056c1705
commit
7a621c443b
20
.github/scripts/gitutils.py
vendored
20
.github/scripts/gitutils.py
vendored
@ -273,6 +273,11 @@ class GitRepo:
|
||||
def amend_commit_message(self, msg: str) -> None:
|
||||
self._run_git("commit", "--amend", "-m", msg)
|
||||
|
||||
def diff(self, from_ref: str, to_ref: Optional[str] = None) -> str:
|
||||
if to_ref is None:
|
||||
return self._run_git("diff", f"{from_ref}^!")
|
||||
return self._run_git("diff", f"{from_ref}..{to_ref}")
|
||||
|
||||
|
||||
def clone_repo(username: str, password: str, org: str, project: str) -> GitRepo:
|
||||
path = tempfile.mkdtemp()
|
||||
@ -331,3 +336,18 @@ def patterns_to_regex(allowed_patterns: List[str]) -> Any:
|
||||
rc += c
|
||||
rc += ")"
|
||||
return re.compile(rc)
|
||||
|
||||
def _shasum(value: str) -> str:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
m.update(value.encode("utf-8"))
|
||||
return m.hexdigest()
|
||||
|
||||
|
||||
def are_ghstack_branches_in_sync(repo: GitRepo, head_ref: str) -> bool:
|
||||
""" Checks that diff between base and head is the same as diff between orig and its parent """
|
||||
orig_ref = re.sub(r'/head$', '/orig', head_ref)
|
||||
base_ref = re.sub(r'/head$', '/base', head_ref)
|
||||
orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}"))
|
||||
head_diff_sha = _shasum(repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}"))
|
||||
return orig_diff_sha == head_diff_sha
|
||||
|
38
.github/scripts/test_gitutils.py
vendored
38
.github/scripts/test_gitutils.py
vendored
@ -1,6 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
from gitutils import PeekableIterator, patterns_to_regex
|
||||
from unittest import TestCase, main
|
||||
from gitutils import PeekableIterator, patterns_to_regex, GitRepo, are_ghstack_branches_in_sync, _shasum
|
||||
from unittest import TestCase, main, SkipTest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class TestPeekableIterator(TestCase):
|
||||
def test_iterator(self, input_: str = "abcdef") -> None:
|
||||
@ -35,5 +40,34 @@ class TestPattern(TestCase):
|
||||
self.assertTrue(patterns_re.match(filename))
|
||||
|
||||
|
||||
class TestGitRepo(TestCase):
|
||||
def setUp(self) -> None:
|
||||
repo_dir = BASE_DIR.parent.parent.absolute()
|
||||
if not (repo_dir / ".git").is_dir():
|
||||
raise SkipTest("Can't find git directory, make sure to run this test on real repo checkout")
|
||||
self.repo = GitRepo(str(repo_dir))
|
||||
|
||||
def _skip_if_ref_does_not_exist(self, ref: str) -> None:
|
||||
""" Skip test if ref is missing as stale branches are deleted with time """
|
||||
try:
|
||||
self.repo.show_ref(ref)
|
||||
except RuntimeError as e:
|
||||
raise SkipTest(f"Can't find head ref {ref} due to {str(e)}") from e
|
||||
|
||||
def test_compute_diff(self) -> None:
|
||||
diff = self.repo.diff("HEAD")
|
||||
sha = _shasum(diff)
|
||||
self.assertEqual(len(sha), 64)
|
||||
|
||||
def test_ghstack_branches_in_sync(self) -> None:
|
||||
head_ref = "gh/SS-JIA/206/head"
|
||||
self._skip_if_ref_does_not_exist(head_ref)
|
||||
self.assertTrue(are_ghstack_branches_in_sync(self.repo, head_ref))
|
||||
|
||||
def test_ghstack_branches_not_in_sync(self) -> None:
|
||||
head_ref = "gh/clee2000/1/head"
|
||||
self._skip_if_ref_does_not_exist(head_ref)
|
||||
self.assertFalse(are_ghstack_branches_in_sync(self.repo, head_ref))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
6
.github/scripts/trymerge.py
vendored
6
.github/scripts/trymerge.py
vendored
@ -29,6 +29,7 @@ from pathlib import Path
|
||||
|
||||
from gitutils import (
|
||||
GitRepo,
|
||||
are_ghstack_branches_in_sync,
|
||||
get_git_remote_name,
|
||||
get_git_repo_dir,
|
||||
patterns_to_regex,
|
||||
@ -619,6 +620,7 @@ def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -
|
||||
return False
|
||||
return comment.author_login == "facebook-github-bot"
|
||||
|
||||
|
||||
def get_ghstack_prs(repo: GitRepo, pr: "GitHubPR") -> List[Tuple["GitHubPR", str]]:
|
||||
'''
|
||||
Get the open PRs in the stack that are below this PR. Throws error if any of the PRs are out of sync.
|
||||
@ -646,9 +648,7 @@ def get_ghstack_prs(repo: GitRepo, pr: "GitHubPR") -> List[Tuple["GitHubPR", str
|
||||
entire_stack.append((pr, rev))
|
||||
|
||||
for stacked_pr, rev in entire_stack:
|
||||
commit_sha = stacked_pr.last_commit()['oid']
|
||||
tree_sha = repo._run_git("rev-parse", commit_sha + "^{tree}")
|
||||
if tree_sha not in repo.commit_message(rev):
|
||||
if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref()):
|
||||
raise RuntimeError(
|
||||
f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on " +
|
||||
f"branch {orig_ref} that would be merged into master. " +
|
||||
|
Reference in New Issue
Block a user