mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
TD Heuristic for tests mentioned in PR body, less verbose TD printing (#120621)
Move tests that are mentioned in PR body or commit message to front. Also attempts to find any issues/PRs mentioned in the PR body and search for those too (ex if you link a disable issue and that issue contains the test file that it was failing on) looking for: dynamo/test_export_mutations Also removes some printed information in TD Pull Request resolved: https://github.com/pytorch/pytorch/pull/120621 Approved by: https://github.com/osalpekar
This commit is contained in:
committed by
PyTorch MergeBot
parent
c7a65f58b0
commit
63ec5cd158
2
.github/workflows/target_determination.yml
vendored
2
.github/workflows/target_determination.yml
vendored
@ -39,6 +39,7 @@ jobs:
|
||||
id: td
|
||||
continue-on-error: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_REPOSITORY: ${{ github.repository }}
|
||||
GITHUB_WORKFLOW: ${{ github.workflow }}
|
||||
GITHUB_JOB: ${{ github.job }}
|
||||
@ -47,6 +48,7 @@ jobs:
|
||||
GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }}
|
||||
JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
|
||||
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
run: |
|
||||
python3 -m pip install boto3==1.19.12
|
||||
python3 tools/testing/do_target_determination_for_s3.py
|
||||
|
@ -20,6 +20,7 @@ def get_test_prioritizations(
|
||||
new_rankings: TestPrioritizations = heuristic.get_prediction_confidence(tests)
|
||||
aggregated_results.add_heuristic_results(heuristic, new_rankings)
|
||||
|
||||
print(new_rankings.get_info_str(), file=file)
|
||||
print(f"Results from {heuristic.__class__.__name__}")
|
||||
print(new_rankings.get_info_str(verbose=False), file=file)
|
||||
|
||||
return aggregated_results
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
from tools.testing.target_determination.heuristics.correlated_with_historical_failures import (
|
||||
CorrelatedWithHistoricalFailures,
|
||||
@ -16,6 +16,7 @@ from tools.testing.target_determination.heuristics.interface import (
|
||||
HeuristicInterface as HeuristicInterface,
|
||||
TestPrioritizations as TestPrioritizations,
|
||||
)
|
||||
from tools.testing.target_determination.heuristics.mentioned_in_pr import MentionedInPR
|
||||
|
||||
from tools.testing.target_determination.heuristics.previously_failed_in_pr import (
|
||||
PreviouslyFailedInPR,
|
||||
@ -27,6 +28,7 @@ from tools.testing.target_determination.heuristics.profiling import Profiling
|
||||
HEURISTICS: List[HeuristicInterface] = [
|
||||
PreviouslyFailedInPR(),
|
||||
EditedByPR(),
|
||||
MentionedInPR(),
|
||||
HistoricalClassFailurCorrelation(trial_mode=True),
|
||||
CorrelatedWithHistoricalFailures(),
|
||||
HistorialEditedFiles(),
|
||||
|
@ -118,13 +118,15 @@ class TestPrioritizations:
|
||||
tests = [x[1] for x in self._traverse_scores()]
|
||||
return tests[: n * len(tests) // 100], tests[n * len(tests) // 100 :]
|
||||
|
||||
def get_info_str(self) -> str:
|
||||
def get_info_str(self, verbose: bool = True) -> str:
|
||||
info = ""
|
||||
|
||||
for score, test in self._traverse_scores():
|
||||
info += f"{test} ({score})\n"
|
||||
if not verbose and score == 0:
|
||||
continue
|
||||
info += f" {test} ({score})\n"
|
||||
|
||||
return info.strip()
|
||||
return info.rstrip()
|
||||
|
||||
def print_info(self) -> None:
|
||||
print(self.get_info_str())
|
||||
|
@ -0,0 +1,66 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List
|
||||
|
||||
from tools.testing.target_determination.heuristics.interface import (
|
||||
HeuristicInterface,
|
||||
TestPrioritizations,
|
||||
)
|
||||
from tools.testing.target_determination.heuristics.utils import (
|
||||
get_git_commit_info,
|
||||
get_issue_or_pr_body,
|
||||
)
|
||||
from tools.testing.test_run import TestRun
|
||||
|
||||
|
||||
# This heuristic searches the PR body and commit titles, as well as issues/PRs
|
||||
# mentioned in the PR body/commit title for test names (search depth of 1) and
|
||||
# gives the test a rating of 1. For example, if I mention "test_foo" in the PR
|
||||
# body, test_foo will be rated 1. If I mention #123 in the PR body, and #123
|
||||
# mentions "test_foo", test_foo will be rated 1.
|
||||
class MentionedInPR(HeuristicInterface):
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _search_for_linked_issues(self, s: str) -> List[str]:
|
||||
return re.findall(r"#(\d+)", s) + re.findall(r"/pytorch/pytorch/.*/(\d+)", s)
|
||||
|
||||
def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations:
|
||||
try:
|
||||
commit_messages = get_git_commit_info()
|
||||
except Exception as e:
|
||||
print(f"Can't get commit info due to {e}")
|
||||
commit_messages = ""
|
||||
try:
|
||||
pr_number = os.environ.get("PR_NUMBER", "")
|
||||
if pr_number == "":
|
||||
re_match = re.match(
|
||||
r"^refs/tags/.*/(\d+)$", os.environ.get("GITHUB_REF", "")
|
||||
)
|
||||
if re_match is not None:
|
||||
pr_number = re_match.group(1)
|
||||
pr_body = get_issue_or_pr_body(int(pr_number))
|
||||
except Exception as e:
|
||||
print(f"Can't get PR body due to {e}")
|
||||
pr_body = ""
|
||||
|
||||
# Search for linked issues or PRs
|
||||
linked_issue_bodies: List[str] = []
|
||||
for issue in self._search_for_linked_issues(
|
||||
commit_messages
|
||||
) + self._search_for_linked_issues(pr_body):
|
||||
try:
|
||||
linked_issue_bodies.append(get_issue_or_pr_body(int(issue)))
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
mentioned = []
|
||||
for test in tests:
|
||||
if (
|
||||
test in commit_messages
|
||||
or test in pr_body
|
||||
or any(test in body for body in linked_issue_bodies)
|
||||
):
|
||||
mentioned.append(test)
|
||||
|
||||
return TestPrioritizations(tests, {TestRun(test): 1 for test in mentioned})
|
@ -4,6 +4,7 @@ import subprocess
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import cast, Dict, List, Set, Union
|
||||
from urllib.request import Request, urlopen
|
||||
from warnings import warn
|
||||
|
||||
from tools.testing.test_run import TestRun
|
||||
@ -48,6 +49,46 @@ def query_changed_files() -> List[str]:
|
||||
return lines
|
||||
|
||||
|
||||
def get_git_commit_info() -> str:
|
||||
"""Gets the commit info since the last commit on the default branch."""
|
||||
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
|
||||
|
||||
merge_base = (
|
||||
subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
|
||||
head = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
|
||||
|
||||
base_commit = merge_base
|
||||
if base_commit == head:
|
||||
# We are on the default branch, so check for changes since the last commit
|
||||
base_commit = "HEAD^"
|
||||
|
||||
return (
|
||||
subprocess.check_output(
|
||||
["git", "log", f"{base_commit}..HEAD"],
|
||||
)
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
|
||||
|
||||
def get_issue_or_pr_body(number: int) -> str:
|
||||
"""Gets the body of an issue or PR"""
|
||||
github_token = os.environ.get("GITHUB_TOKEN")
|
||||
headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"Authorization": f"token {github_token}",
|
||||
}
|
||||
# Despite the 'issues' in the link, this also works for PRs
|
||||
url = f"https://api.github.com/repos/pytorch/pytorch/issues/{number}"
|
||||
with urlopen(Request(url, headers=headers)) as conn:
|
||||
body: str = json.loads(conn.read().decode())["body"]
|
||||
return body
|
||||
|
||||
|
||||
def normalize_ratings(
|
||||
ratings: Dict[TestRun, float], max_value: float
|
||||
) -> Dict[TestRun, float]:
|
||||
|
Reference in New Issue
Block a user