[GHF][mergebot] record ghstack dependencies in the commit message (#105251)

Currently all information about the dependencies of ghstack PRs (e.g. #105010) is stripped away:
c984885809/.github/scripts/trymerge.py (L1077-L1078)

This PR adds this information back in a more compact form. All dependencies (PR numbers) of each PR in ghstack are recorded.

The resulting commit message will look like this (the last line is new):

> Mock title (#123)
>
> Mock body text
> Pull Request resolved: https://github.com/pytorch/pytorch/pull/123
> Approved by: https://github.com/Approver1, https://github.com/Approver2
> ghstack dependencies: #1, #2

---

### Testing

Unit tests.

---

### Note Re: `# type: ignore[assignment]` in unit tests.

I did my due diligence to find alternatives. Unfortunately mypy [doesn't](https://github.com/python/mypy/issues/6713) support this [way of patching methods](https://docs.python.org/3/library/unittest.mock-examples.html#mock-patching-methods), and the alternatives are either extremely verbose or don't work for this case. I decided it's not worth the effort (since the problem is limited only to the unit test).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105251
Approved by: https://github.com/huydhn
This commit is contained in:
Ivan Zaitsev
2023-07-29 20:32:10 +00:00
committed by PyTorch MergeBot
parent 0ee3b84021
commit d2aa3f5fa9
4 changed files with 2520 additions and 8 deletions

2392
.github/scripts/gql_mocks.json generated vendored

File diff suppressed because it is too large Load Diff

View File

@ -42650,5 +42650,6 @@
"failure_captures": null,
"steps": 0
}
]
],
"f9a10a148958be8bd6474b8ab36da3f0aedd6a70 dc0bf418c1045b64911c56a27e685d2bb4ffa5d3": []
}

View File

@ -836,5 +836,103 @@ class TestBypassFailures(TestCase):
)
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
class TestGitHubPRGhstackDependencies2(TestCase):
def test_pr_dependencies(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 106068)
msg = pr.gen_commit_message(filter_ghstack=True)
assert msg == (
"[FSDP] Break up `_post_backward_hook` into smaller funcs (#106068)\n\n\nDifferential Revision: ["
"D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)\nPull Request resolved: "
"https://github.com/pytorch/pytorch/pull/106068\nApproved by: \n"
)
def test_pr_dependencies_ghstack(self, *args: Any) -> None:
pr0 = GitHubPR("pytorch", "pytorch", 106032)
pr1 = GitHubPR("pytorch", "pytorch", 106033)
pr2 = GitHubPR("pytorch", "pytorch", 106034)
pr = GitHubPR("pytorch", "pytorch", 106068)
msg = pr.gen_commit_message(filter_ghstack=True, ghstack_deps=[pr0, pr1, pr2])
assert msg == (
"[FSDP] Break up `_post_backward_hook` into smaller funcs (#106068)\n\n\nDifferential Revision: ["
"D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)\nPull Request resolved: "
"https://github.com/pytorch/pytorch/pull/106068\nApproved by: \n"
"ghstack dependencies: #106032, #106033, #106034\n"
)
@mock.patch("trymerge.read_merge_rules")
@mock.patch("trymerge.GitRepo")
@mock.patch("trymerge.get_ghstack_prs")
def test_merge_ghstack_into(
self,
mock_get_ghstack_prs: mock.MagicMock,
mock_repo: mock.MagicMock,
mock_merge_rules: mock.MagicMock,
*args: Any,
) -> None:
"""
Test that the merge_ghstack_into method works correctly
"""
pr0 = GitHubPR("pytorch", "pytorch", 106032)
pr1 = GitHubPR("pytorch", "pytorch", 106033)
pr2 = GitHubPR("pytorch", "pytorch", 106034)
pr = GitHubPR("pytorch", "pytorch", 106068)
# note: in reverse order (e.g. self.pr is the last commit, top of the stack)
mock_get_ghstack_prs.return_value = [
(pr0, "rev0"),
(pr1, "rev1"),
(pr2, "rev2"),
(pr, "rev123"),
]
mock_merge_rules.return_value = [
MergeRule(
"Mock title", patterns=["*"], approved_by=[], mandatory_checks_name=None
)
]
mock_repo.cherry_pick.return_value = None
mock_repo.amend_commit_message.return_value = None
# Call the method under test
res = pr.merge_ghstack_into(mock_repo, True)
self.assertEqual(res, [pr2, pr])
mock_repo.cherry_pick.assert_any_call("rev2")
mock_repo.cherry_pick.assert_any_call("rev123")
assert mock.call("rev1") not in mock_repo.cherry_pick.call_args_list
# Verify the first call
message = mock_repo.amend_commit_message.call_args_list[0].args[0]
prefix = (
"[FSDP] Optimize away intermediate `div_` for HSDP (#106034)\n\n\r\n"
"### Background: Gradient Pre-Divide"
)
suffix = (
"\nPull Request resolved: https://github.com/pytorch/pytorch/pull/106034\nApproved by: \nghstack "
"dependencies: #106032, #106033\n"
)
assert message.startswith(prefix)
assert message.endswith(suffix)
# Verify the second call
mock_repo.amend_commit_message.assert_any_call(
"[FSDP] Break up `_post_backward_hook` into smaller funcs (#106068)\n\n\n"
"Differential Revision: ["
"D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)\n"
"Pull Request resolved: "
"https://github.com/pytorch/pytorch/pull/106068\n"
"Approved by: \n"
"ghstack dependencies: #106032, #106033, #106034\n"
)
if __name__ == "__main__":
main()

View File

@ -622,9 +622,12 @@ def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -
return comment.author_login == "facebook-github-bot"
def get_ghstack_prs(repo: GitRepo, pr: "GitHubPR") -> List[Tuple["GitHubPR", str]]:
def get_ghstack_prs(
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
) -> List[Tuple["GitHubPR", str]]:
"""
Get the open PRs in the stack that are below this PR. Throws error if any of the PRs are out of sync.
Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
@:param open_only: Only return open PRs
"""
assert pr.is_ghstack_pr()
entire_stack: List[Tuple[GitHubPR, str]] = []
@ -645,7 +648,7 @@ def get_ghstack_prs(repo: GitRepo, pr: "GitHubPR") -> List[Tuple["GitHubPR", str
stacked_pr_num = int(m.group("number"))
if stacked_pr_num != pr.pr_num:
stacked_pr = GitHubPR(pr.org, pr.project, stacked_pr_num)
if stacked_pr.is_closed():
if open_only and stacked_pr.is_closed():
print(
f"Skipping {idx+1} of {len(rev_list)} PR (#{stacked_pr_num}) as its already been merged"
)
@ -655,6 +658,8 @@ def get_ghstack_prs(repo: GitRepo, pr: "GitHubPR") -> List[Tuple["GitHubPR", str
entire_stack.append((pr, rev))
for stacked_pr, rev in entire_stack:
if stacked_pr.is_closed():
continue
if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref()):
raise RuntimeError(
f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on "
@ -1037,9 +1042,18 @@ class GitHubPR:
comment_id: Optional[int] = None,
) -> List["GitHubPR"]:
assert self.is_ghstack_pr()
ghstack_prs = get_ghstack_prs(repo, self) # raises error if out of sync
ghstack_prs = get_ghstack_prs(
repo, self, open_only=False
) # raises error if out of sync
pr_dependencies = []
for pr, rev in ghstack_prs:
commit_msg = pr.gen_commit_message(filter_ghstack=True)
if pr.is_closed():
pr_dependencies.append(pr)
continue
commit_msg = pr.gen_commit_message(
filter_ghstack=True, ghstack_deps=pr_dependencies
)
if pr.pr_num != self.pr_num:
# Raises exception if matching rule is not found
find_matching_merge_rule(
@ -1050,9 +1064,14 @@ class GitHubPR:
)
repo.cherry_pick(rev)
repo.amend_commit_message(commit_msg)
return [x for x, _ in ghstack_prs]
pr_dependencies.append(pr)
return [x for x, _ in ghstack_prs if not x.is_closed()]
def gen_commit_message(self, filter_ghstack: bool = False) -> str:
def gen_commit_message(
self,
filter_ghstack: bool = False,
ghstack_deps: Optional[List["GitHubPR"]] = None,
) -> str:
"""Fetches title and body from PR description
adds reviewed by, pull request resolved and optionally
filters out ghstack info"""
@ -1068,6 +1087,8 @@ class GitHubPR:
msg += msg_body
msg += f"\nPull Request resolved: {self.get_pr_url()}\n"
msg += f"Approved by: {approved_by_urls}\n"
if ghstack_deps:
msg += f"ghstack dependencies: {', '.join([f'#{pr.pr_num}' for pr in ghstack_deps])}\n"
return msg
def add_numbered_label(self, label_base: str) -> None: