mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136359 Approved by: https://github.com/albanD
66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Any
|
|
|
|
from tools.testing.target_determination.heuristics.interface import (
|
|
HeuristicInterface,
|
|
TestPrioritizations,
|
|
)
|
|
from tools.testing.target_determination.heuristics.utils import (
|
|
get_git_commit_info,
|
|
get_issue_or_pr_body,
|
|
get_pr_number,
|
|
)
|
|
from tools.testing.test_run import TestRun
|
|
|
|
|
|
# This heuristic searches the PR body and commit titles, as well as issues/PRs
|
|
# mentioned in the PR body/commit title for test names (search depth of 1) and
|
|
# gives the test a rating of 1. For example, if I mention "test_foo" in the PR
|
|
# body, test_foo will be rated 1. If I mention #123 in the PR body, and #123
|
|
# mentions "test_foo", test_foo will be rated 1.
|
|
class MentionedInPR(HeuristicInterface):
    """Rate tests that are mentioned in the PR, its commits, or linked issues/PRs.

    Texts searched for test names:

    * the commit messages returned by ``get_git_commit_info()``;
    * the body of the current PR, when ``get_pr_number()`` returns a number;
    * the bodies of issues/PRs referenced from either of the above, either as
      ``#<number>`` or as a ``.../pytorch/pytorch/.../<number>`` URL. Only one
      level of links is followed (linked bodies are not searched for more links).

    A test counts as mentioned when its name occurs as a substring of any of
    these texts. Every lookup is best-effort: a failure is printed and the
    corresponding text is treated as missing.
    """

    # "#123"-style references.
    _HASH_REF_RE = re.compile(r"#(\d+)")
    # URL references such as .../pytorch/pytorch/issues/123 or .../pull/123/files.
    # Non-greedy and whitespace-bounded so that several URLs on the same line are
    # each matched; a greedy ``.*`` would only yield the last number on the line.
    _URL_REF_RE = re.compile(r"/pytorch/pytorch/\S*?/(\d+)")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def _search_for_linked_issues(self, s: str) -> list[str]:
        """Return the issue/PR numbers referenced in ``s`` as digit strings.

        ``#N`` references come first, then URL references, each group in order
        of appearance. Duplicates are kept; callers de-duplicate if needed.
        """
        return self._HASH_REF_RE.findall(s) + self._URL_REF_RE.findall(s)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """Give a score of 1 to every test in ``tests`` mentioned in the PR context.

        Args:
            tests: candidate test names; each is matched as a plain substring.

        Returns:
            A ``TestPrioritizations`` built from all ``tests`` plus a mapping
            that assigns 1 to each mentioned test.
        """
        try:
            commit_messages = get_git_commit_info()
        except Exception as e:
            print(f"Can't get commit info due to {e}")
            commit_messages = ""

        # Initialised up front so both stay defined if get_pr_number() raises.
        pr_number = None
        pr_body = ""
        try:
            pr_number = get_pr_number()
            if pr_number is not None:
                pr_body = get_issue_or_pr_body(pr_number)
        except Exception as e:
            print(f"Can't get PR body due to {e}")
            pr_body = ""

        # Referenced issue/PR numbers, de-duplicated in first-seen order so each
        # body is fetched at most once. The regexes only capture digits, so int()
        # cannot fail here.
        issue_numbers = dict.fromkeys(
            int(issue)
            for text in (commit_messages, pr_body)
            for issue in self._search_for_linked_issues(text)
        )
        # The current PR's body is already in pr_body; commit messages may link
        # back to it, so don't fetch it a second time.
        issue_numbers.pop(pr_number, None)

        linked_issue_bodies: list[str] = []
        for issue_number in issue_numbers:
            try:
                linked_issue_bodies.append(get_issue_or_pr_body(issue_number))
            except Exception as e:
                # Best effort: one failed lookup shouldn't disable the heuristic.
                print(f"Can't get body of issue/PR #{issue_number} due to {e}")

        texts = [commit_messages, pr_body, *linked_issue_bodies]
        mentioned = [test for test in tests if any(test in text for text in texts)]
        return TestPrioritizations(tests, {TestRun(test): 1 for test in mentioned})
|