[GHF] Add support for new style stacks (#116873)

Where base stack targets default branch, rather than base. But as
default branch is likely to advance, since PR was made, search for
mergebase before determining whether `base`..`head` are in sync with `orig` branch
Also, rather than hardcode default branch name, fetch it from `GitHubPR.default_branch()`

Test Plan: https://github.com/malfet/deleteme/pull/77

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116873
Approved by: https://github.com/ezyang
This commit is contained in:
Nikita Shulga
2024-01-05 20:32:24 +00:00
committed by PyTorch MergeBot
parent 71d8fe690f
commit 0f0020d76f
2 changed files with 25 additions and 5 deletions

View File

@ -397,13 +397,28 @@ def _shasum(value: str) -> str:
return m.hexdigest() return m.hexdigest()
def are_ghstack_branches_in_sync(repo: GitRepo, head_ref: str) -> bool: def is_commit_hash(ref: str) -> bool:
"True if ref is hexadecimal number, else false"
try:
int(ref, 16)
except ValueError:
return False
return True
def are_ghstack_branches_in_sync(
repo: GitRepo, head_ref: str, base_ref: Optional[str] = None
) -> bool:
"""Checks that diff between base and head is the same as diff between orig and its parent""" """Checks that diff between base and head is the same as diff between orig and its parent"""
orig_ref = re.sub(r"/head$", "/orig", head_ref) orig_ref = re.sub(r"/head$", "/orig", head_ref)
base_ref = re.sub(r"/head$", "/base", head_ref) if base_ref is None:
base_ref = re.sub(r"/head$", "/base", head_ref)
orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}")) orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}"))
head_diff_sha = _shasum( head_diff_sha = _shasum(
repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}") repo.diff(
base_ref if is_commit_hash(base_ref) else f"{repo.remote}/{base_ref}",
f"{repo.remote}/{head_ref}",
)
) )
return orig_diff_sha == head_diff_sha return orig_diff_sha == head_diff_sha

View File

@ -674,10 +674,15 @@ def get_ghstack_prs(
for stacked_pr, rev in entire_stack: for stacked_pr, rev in entire_stack:
if stacked_pr.is_closed(): if stacked_pr.is_closed():
continue continue
if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref()): base_ref = stacked_pr.base_ref()
if base_ref == pr.default_branch():
base_ref = repo.get_merge_base(
f"{repo.remote}/{base_ref}", f"{repo.remote}/{stacked_pr.head_ref()}"
)
if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref(), base_ref):
raise RuntimeError( raise RuntimeError(
f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on " f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on "
+ f"branch {orig_ref} that would be merged into main. " + f"branch {stacked_pr.get_ghstack_orig_ref()} that would be merged into {stacked_pr.default_branch()}. "
+ "This usually happens because there is a non ghstack change in the PR. " + "This usually happens because there is a non ghstack change in the PR. "
+ f"Please sync them and try again (ex. make the changes on {orig_ref} and run ghstack)." + f"Please sync them and try again (ex. make the changes on {orig_ref} and run ghstack)."
) )