mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TD] Filepath heuristic also looks at file name (#140170)
Filepath heuristic also now takes into account the file name, not just directories A bit of refactoring Pull Request resolved: https://github.com/pytorch/pytorch/pull/140170 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
5f7ea7ca6a
commit
b742d11b1c
@ -155,11 +155,12 @@ class TestParsePrevTests(TestTD):
|
||||
|
||||
class TestFilePath(TestTD):
|
||||
def test_get_keywords(self) -> None:
|
||||
self.assertEqual(get_keywords("test/test_car.py"), [])
|
||||
self.assertEqual(get_keywords("test/nn/test_amp.py"), ["nn"])
|
||||
self.assertEqual(get_keywords("torch/nn/test_amp.py"), ["nn"])
|
||||
self.assertEqual(get_keywords("test/test_car.py"), ["car"])
|
||||
self.assertEqual(get_keywords("test/nn/test_amp.py"), ["nn", "amp"])
|
||||
self.assertEqual(get_keywords("torch/nn/test_amp.py"), ["nn", "amp"])
|
||||
self.assertEqual(
|
||||
get_keywords("torch/nn/mixed_precision/test_amp.py"), ["nn", "amp"]
|
||||
get_keywords("torch/nn/mixed_precision/test_something.py"),
|
||||
["nn", "amp", "something"],
|
||||
)
|
||||
|
||||
def test_match_keywords(self) -> None:
|
||||
|
@ -30,19 +30,6 @@ keyword_synonyms: dict[str, list[str]] = {
|
||||
"inductor": ["dynamo", "export"], # not actually synonyms but they interact a lot
|
||||
}
|
||||
|
||||
not_keyword = [
|
||||
"torch",
|
||||
"test",
|
||||
"tests",
|
||||
"util",
|
||||
"utils",
|
||||
"func",
|
||||
"src",
|
||||
"c",
|
||||
"ns",
|
||||
"tools",
|
||||
"internal",
|
||||
]
|
||||
|
||||
custom_matchers: dict[str, Callable[[str], bool]] = {
|
||||
"nn": lambda x: "nn" in x.replace("onnx", "_"),
|
||||
@ -50,16 +37,36 @@ custom_matchers: dict[str, Callable[[str], bool]] = {
|
||||
}
|
||||
|
||||
|
||||
def is_valid_keyword(keyword: str) -> bool:
|
||||
not_keyword = [
|
||||
"torch",
|
||||
"test",
|
||||
"tests",
|
||||
"util",
|
||||
"utils",
|
||||
"func",
|
||||
"src",
|
||||
"c",
|
||||
"ns",
|
||||
"tools",
|
||||
"internal",
|
||||
]
|
||||
return keyword == "nn" or (keyword not in not_keyword and len(keyword) > 2)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_keywords(file: str) -> list[str]:
|
||||
keywords = []
|
||||
for folder in Path(file).parts[:-1]:
|
||||
folder = sanitize_folder_name(folder)
|
||||
folder = sanitize_name(folder)
|
||||
keywords.append(folder)
|
||||
return [kw for kw in keywords if kw not in not_keyword]
|
||||
|
||||
file_name = Path(file).stem.split("_")
|
||||
keywords.extend([sanitize_name(x) for x in file_name])
|
||||
return [kw for kw in keywords if is_valid_keyword(kw)]
|
||||
|
||||
|
||||
def sanitize_folder_name(folder_name: str) -> str:
|
||||
def sanitize_name(folder_name: str) -> str:
|
||||
if folder_name.startswith("_"):
|
||||
folder_name = folder_name[1:]
|
||||
|
||||
@ -81,6 +88,22 @@ def file_matches_keyword(file: str, keyword: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def get_freq_dict(tests: list[str], changed_files: list[str]) -> dict[str, int]:
|
||||
keyword_frequency: dict[str, int] = defaultdict(int)
|
||||
for cf in changed_files:
|
||||
keywords = get_keywords(cf)
|
||||
for keyword in keywords:
|
||||
keyword_frequency[keyword] += 1
|
||||
|
||||
test_ratings: dict[str, int] = defaultdict(int)
|
||||
|
||||
for test in tests:
|
||||
for keyword, frequency in keyword_frequency.items():
|
||||
if file_matches_keyword(test, keyword):
|
||||
test_ratings[test] += frequency
|
||||
return test_ratings
|
||||
|
||||
|
||||
class Filepath(HeuristicInterface):
|
||||
# Heuristic based on folders in the file path. Takes each folder of each
|
||||
# changed file and attempts to find matches based on those folders
|
||||
@ -88,25 +111,33 @@ class Filepath(HeuristicInterface):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
|
||||
keyword_frequency: dict[str, int] = defaultdict(int)
|
||||
try:
|
||||
changed_files = query_changed_files()
|
||||
except Exception as e:
|
||||
warn(f"Can't query changed test files due to {e}")
|
||||
changed_files = []
|
||||
|
||||
for cf in changed_files:
|
||||
keywords = get_keywords(cf)
|
||||
for keyword in keywords:
|
||||
keyword_frequency[keyword] += 1
|
||||
|
||||
test_ratings: dict[str, float] = defaultdict(float)
|
||||
|
||||
for test in tests:
|
||||
for keyword, frequency in keyword_frequency.items():
|
||||
if file_matches_keyword(test, keyword):
|
||||
test_ratings[test] += frequency
|
||||
test_ratings = {TestRun(k): v for (k, v) in test_ratings.items() if k in tests}
|
||||
test_ratings = get_freq_dict(tests, changed_files)
|
||||
test_ratings = {
|
||||
TestRun(k): float(v) for (k, v) in test_ratings.items() if k in tests
|
||||
}
|
||||
return TestPrioritizations(
|
||||
tests, normalize_ratings(test_ratings, 0.25, min_value=0.125)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Quick thing so you can call the heuristic from the command line with a sha
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tools.testing.discover_tests import TESTS
|
||||
|
||||
git_diff = f"git diff --name-only {sys.argv[1]} {sys.argv[1]}^"
|
||||
changed_files = os.popen(git_diff).read().split("\n")
|
||||
freq_dict = get_freq_dict(
|
||||
TESTS, [x for x in changed_files if x != "" and not x.startswith("test")]
|
||||
)
|
||||
for k, v in sorted(freq_dict.items(), key=lambda x: x[1], reverse=False):
|
||||
print(k, v)
|
||||
print(changed_files)
|
||||
|
Reference in New Issue
Block a user