mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Check commit order (#161560)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161560 Approved by: https://github.com/malfet ghstack dependencies: #161558, #161637
This commit is contained in:
committed by
PyTorch MergeBot
parent
b99a112688
commit
c8fa907e74
174
.github/scripts/test_trymerge.py
vendored
174
.github/scripts/test_trymerge.py
vendored
@ -27,6 +27,7 @@ from trymerge import (
|
||||
get_drci_classifications,
|
||||
gh_get_team_members,
|
||||
GitHubPR,
|
||||
iter_issue_timeline_until_comment,
|
||||
JobCheckState,
|
||||
main as trymerge_main,
|
||||
MandatoryChecksMissingError,
|
||||
@ -34,6 +35,8 @@ from trymerge import (
|
||||
RE_GHSTACK_DESC,
|
||||
read_merge_rules,
|
||||
remove_job_name_suffix,
|
||||
sha_from_committed_event,
|
||||
sha_from_force_push_after,
|
||||
validate_revert,
|
||||
)
|
||||
|
||||
@ -1138,5 +1141,176 @@ Pull Request resolved: https://github.com/pytorch/pytorch/pull/154394"""
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
    "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestTimelineFunctions(TestCase):
    """Tests for the timeline-related functions.

    The class-level patches are applied to every test method; the mocks they
    inject are absorbed by each method's ``*args``.
    """

    def test_sha_from_committed_event(self, *args: Any) -> None:
        """Test extracting SHA from committed event"""
        # Based on actual GitHub API format - committed events have "sha" at top level
        event = {
            "event": "committed",
            "sha": "fb21ce932ded6670c918804a0d9151b773770a7c",
        }
        self.assertEqual(
            sha_from_committed_event(event), "fb21ce932ded6670c918804a0d9151b773770a7c"
        )

        # Test with missing SHA
        event_no_sha = {"event": "committed"}
        self.assertIsNone(sha_from_committed_event(event_no_sha))

    def test_sha_from_force_push_after(self, *args: Any) -> None:
        """Test extracting SHA from force push event"""
        # Legacy format: the SHA is nested under the "after" object.
        event_legacy = {
            "event": "head_ref_force_pushed",
            "after": {"sha": "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e"},
        }
        self.assertEqual(
            sha_from_force_push_after(event_legacy),
            "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        )

        # Current GitHub API format: the SHA is a top-level "commit_id".
        event_real_api = {
            "event": "head_ref_force_pushed",
            "commit_id": "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        }
        self.assertEqual(
            sha_from_force_push_after(event_real_api),
            "ef22bcbc54bb0f787e1e4ffd3d83df18fc407f5e",
        )

        # Test with missing SHA
        event_no_sha = {"event": "head_ref_force_pushed"}
        self.assertIsNone(sha_from_force_push_after(event_no_sha))

    @mock.patch("trymerge.gh_fetch_json_list")
    def test_iter_issue_timeline_until_comment(
        self, mock_gh_fetch_json_list: Any, *args: Any
    ) -> None:
        """Test timeline iteration until target comment"""
        # Mock timeline data based on actual GitHub API format
        timeline_data = [
            {"event": "commented", "id": 100, "body": "first comment"},
            {"event": "committed", "sha": "fb21ce932ded6670c918804a0d9151b773770a7c"},
            {"event": "commented", "id": 200, "body": "target comment"},
            {"event": "commented", "id": 300, "body": "after target"},
        ]
        mock_gh_fetch_json_list.return_value = timeline_data

        # Test iteration stops at target comment
        events = list(iter_issue_timeline_until_comment("pytorch", "pytorch", 123, 200))
        self.assertEqual(len(events), 3)  # Should stop at target comment
        self.assertEqual(events[0]["event"], "commented")
        self.assertEqual(events[0]["id"], 100)
        self.assertEqual(events[1]["event"], "committed")
        self.assertEqual(events[1]["sha"], "fb21ce932ded6670c918804a0d9151b773770a7c")
        self.assertEqual(events[2]["event"], "commented")
        self.assertEqual(events[2]["id"], 200)

    @mock.patch("trymerge.gh_fetch_json_list")
    def test_iter_issue_timeline_until_comment_not_found(
        self, mock_gh_fetch_json_list: Any, *args: Any
    ) -> None:
        """Test timeline iteration when target comment is not found"""
        # Mock empty timeline
        mock_gh_fetch_json_list.return_value = []

        events = list(iter_issue_timeline_until_comment("pytorch", "pytorch", 123, 999))
        self.assertEqual(len(events), 0)

        # A non-empty timeline that never contains the target is yielded in
        # full; a page shorter than the page size ends the iteration.
        timeline_data = [{"event": "commented", "id": 100, "body": "first comment"}]
        mock_gh_fetch_json_list.return_value = timeline_data

        events = list(iter_issue_timeline_until_comment("pytorch", "pytorch", 123, 999))
        self.assertEqual(events, timeline_data)

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_commit_after_comment(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Test that a force push after the comment does not change the returned SHA"""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "commented", "id": 100},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit2")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_force_push_before_comment(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Test that a force push ("commit_id" format) before the comment wins"""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "head_ref_force_pushed", "commit_id": "commit3"},
            {"event": "commented", "id": 100},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_force_push_before_comment_legacy_mode(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Test that a force push (legacy "after" format) before the comment wins"""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "committed", "sha": "commit2"},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
            {"event": "commented", "id": 100},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_multiple_comments(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Test that each comment id resolves to the head seen just before it"""
        mock_iter_timeline.return_value = [
            {"event": "committed", "sha": "commit1"},
            {"event": "commented", "id": 100},
            {"event": "committed", "sha": "commit2"},
            {"event": "commented", "id": 200},
            {"event": "head_ref_force_pushed", "after": {"sha": "commit3"}},
            {"event": "commented", "id": 300},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(200)
        self.assertEqual(sha, "commit2")
        sha = pr.get_commit_sha_at_comment(300)
        self.assertEqual(sha, "commit3")

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_no_events(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Test that None is returned when no commit or force push precedes the comment"""
        mock_iter_timeline.return_value = [
            {"event": "commented", "id": 100},
            {"event": "labeled", "label": {"name": "test"}},
        ]
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertIsNone(sha)

    @mock.patch("trymerge.iter_issue_timeline_until_comment")
    def test_get_commit_sha_at_comment_exception(
        self, mock_iter_timeline: Any, *args: Any
    ) -> None:
        """Test that a failure while reading the timeline yields None"""
        mock_iter_timeline.side_effect = Exception("API error")
        pr = GitHubPR("pytorch", "pytorch", 77700)
        sha = pr.get_commit_sha_at_comment(100)
        self.assertIsNone(sha)
|
||||
|
||||
|
||||
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    main()
|
||||
|
131
.github/scripts/trymerge.py
vendored
131
.github/scripts/trymerge.py
vendored
@ -450,6 +450,63 @@ HAS_NO_CONNECTED_DIFF_TITLE = (
|
||||
IGNORABLE_FAILED_CHECKS_THESHOLD = 10
|
||||
|
||||
|
||||
def iter_issue_timeline_until_comment(
    org: str, repo: str, issue_number: int, target_comment_id: int, max_pages: int = 200
) -> Any:
    """
    Yield timeline entries in order until (and including) the entry whose id == target_comment_id
    for a 'commented' event. Stops once the target comment is encountered.

    The timeline is fetched page by page through ``gh_fetch_json_list``. Iteration
    also ends silently, without having seen the target, when a page comes back
    empty or shorter than the requested page size.

    Raises:
        RuntimeError: if ``max_pages`` full pages were consumed (or ``max_pages``
            is not positive) without the target comment being found.
    """
    per_page = 100
    # Loop-invariant, and bound up front so the error below can always report it,
    # even when the loop body never runs.
    url = f"https://api.github.com/repos/{org}/{repo}/issues/{issue_number}/timeline"

    for page in range(1, max_pages + 1):
        batch = gh_fetch_json_list(url, {"per_page": per_page, "page": page})

        if not batch:
            return
        for ev in batch:
            yield ev
            # The target is the issue comment row with event == "commented" and
            # id == target_comment_id; nothing in the timeline after it matters.
            if ev.get("event") == "commented" and ev.get("id") == target_comment_id:
                return
        if len(batch) < per_page:
            # A short page means there are no further pages to request.
            return

    # If we got here without finding the comment, then we either hit a bug or some github PR
    # has a _really_ long timeline.
    # The max # of pages found on any pytorch/pytorch PR at the time of this change was 41
    raise RuntimeError(
        f"Could not find comment {target_comment_id} in the first {max_pages} pages "
        f"of the timeline at url {url}. "
        "This is most likely a bug, please report it to the @pytorch/pytorch-dev-infra team."
    )
|
||||
|
||||
|
||||
def sha_from_committed_event(ev: dict[str, Any]) -> Optional[str]:
    """Return the top-level ``sha`` field of a timeline event, or None if it has none."""
    commit_sha = ev.get("sha")
    return commit_sha
|
||||
|
||||
|
||||
def sha_from_force_push_after(ev: dict[str, Any]) -> Optional[str]:
    """Extract SHA from force push event in timeline

    Fields are tried in this order:
      1. top-level ``commit_id`` (the current GitHub API format);
      2. ``sha`` or ``oid`` inside an ``after`` / ``after_commit`` object (legacy);
      3. top-level ``after_sha`` or ``head_sha`` (legacy).

    Returns None when none of these fields carries a value.
    """
    # The current GitHub API format
    commit_id = ev.get("commit_id")
    if commit_id:
        return str(commit_id)

    # Legacy format. Do not default ``after`` to an empty dict: that made the
    # isinstance check always succeed and the flat fallback below unreachable.
    after = ev.get("after") or ev.get("after_commit")
    if isinstance(after, dict):
        nested_sha = after.get("sha") or after.get("oid")
        if nested_sha:
            return nested_sha
    return ev.get("after_sha") or ev.get("head_sha")
|
||||
|
||||
|
||||
def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
    """Run GH_GET_PR_INFO_QUERY for the given PR and return its ``pullRequest`` payload."""
    response = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no)
    repository = response["data"]["repository"]
    return repository["pullRequest"]
|
||||
@ -843,6 +900,44 @@ class GitHubPR:
|
||||
def get_commit_count(self) -> int:
    """Return ``totalCount`` of ``commits_with_authors`` from the PR info, as an int."""
    commits = self.info["commits_with_authors"]
    return int(commits["totalCount"])
|
||||
|
||||
def get_commit_sha_at_comment(self, comment_id: int) -> Optional[str]:
    """
    Get the PR head commit SHA that was present when a specific comment was posted.
    This ensures we only merge the state of the PR at the time the merge command was issued,
    not any subsequent commits that may have been pushed after.

    The timeline is replayed in order: every ``committed`` or
    ``head_ref_force_pushed`` event that carries a SHA becomes the new head,
    and the head in effect when the target comment is reached is returned.

    Returns None if no head-changing events found before the comment, if the comment was not
    found, or if reading the timeline raised (the error is printed, not propagated).
    """
    head = None

    try:
        for event in iter_issue_timeline_until_comment(
            self.org, self.project, self.pr_num, comment_id
        ):
            etype = event.get("event")
            if etype == "committed":
                sha = sha_from_committed_event(event)
                if sha:
                    head = sha
                    print(f"Timeline: Found commit event for SHA {sha}")
            elif etype == "head_ref_force_pushed":
                sha = sha_from_force_push_after(event)
                if sha:
                    head = sha
                    print(f"Timeline: Found force push event for SHA {sha}")
            elif etype == "commented" and event.get("id") == comment_id:
                # Report `head`, not the loop-local `sha`: `sha` is unbound when
                # no head-changing event preceded the comment, and may be stale
                # (None from an event without a SHA) otherwise.
                print(f"Timeline: Found final comment with sha {head}")
                return head
    except Exception as e:
        # Best effort: callers treat None as "could not determine the commit".
        print(
            f"Warning: Failed to reconstruct timeline for comment {comment_id}: {e}"
        )
        return None

    print(f"Did not find comment with id {comment_id} in the PR timeline")
    return None
|
||||
|
||||
def get_pr_creator_login(self) -> str:
    """Return the ``login`` of the PR's ``author`` entry from the PR info."""
    author = self.info["author"]
    return cast(str, author["login"])
|
||||
|
||||
@ -1234,11 +1329,14 @@ class GitHubPR:
|
||||
skip_all_rule_checks: bool = False,
|
||||
) -> list["GitHubPR"]:
|
||||
"""
|
||||
:param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
|
||||
:param skip_all_rule_checks: If true, skips all rule checks on ghstack PRs, useful for dry-running merge locally
|
||||
"""
|
||||
branch_to_merge_into = self.default_branch() if branch is None else branch
|
||||
if repo.current_branch() != branch_to_merge_into:
|
||||
repo.checkout(branch_to_merge_into)
|
||||
|
||||
# It's okay to skip the commit SHA check for ghstack PRs since
|
||||
# authoring requires write access to the repo.
|
||||
if self.is_ghstack_pr():
|
||||
return self.merge_ghstack_into(
|
||||
repo,
|
||||
@ -1249,14 +1347,41 @@ class GitHubPR:
|
||||
|
||||
msg = self.gen_commit_message()
|
||||
pr_branch_name = f"__pull-request-{self.pr_num}__init__"
|
||||
repo.fetch(self.last_commit_sha(), pr_branch_name)
|
||||
|
||||
# Determine which commit SHA to merge
|
||||
commit_to_merge = None
|
||||
if not comment_id:
|
||||
raise ValueError("Must provide --comment-id when merging regular PRs")
|
||||
|
||||
# Get the commit SHA that was present when the comment was made
|
||||
commit_to_merge = self.get_commit_sha_at_comment(comment_id)
|
||||
if not commit_to_merge:
|
||||
raise RuntimeError(
|
||||
f"Could not find commit that was pushed before comment {comment_id}"
|
||||
)
|
||||
|
||||
# Validate that this commit is the latest commit on the PR
|
||||
latest_commit = self.last_commit_sha()
|
||||
if commit_to_merge != latest_commit:
|
||||
raise RuntimeError(
|
||||
f"Commit {commit_to_merge} was HEAD when comment {comment_id} was posted "
|
||||
f"but now the latest commit on the PR is {latest_commit}. "
|
||||
f"Please re-issue the merge command to merge the latest commit."
|
||||
)
|
||||
|
||||
print(f"Merging commit {commit_to_merge} locally")
|
||||
|
||||
repo.fetch(commit_to_merge, pr_branch_name)
|
||||
repo._run_git("merge", "--squash", pr_branch_name)
|
||||
repo._run_git("commit", f'--author="{self.get_author()}"', "-m", msg)
|
||||
|
||||
# Did the PR change since we started the merge?
|
||||
pulled_sha = repo.show_ref(pr_branch_name)
|
||||
latest_pr_status = GitHubPR(self.org, self.project, self.pr_num)
|
||||
if pulled_sha != latest_pr_status.last_commit_sha():
|
||||
if (
|
||||
pulled_sha != latest_pr_status.last_commit_sha()
|
||||
or pulled_sha != commit_to_merge
|
||||
):
|
||||
raise RuntimeError(
|
||||
"PR has been updated since CI checks last passed. Please rerun the merge command."
|
||||
)
|
||||
|
Reference in New Issue
Block a user