Test TD (test removal) on crossref (#119426)

Current threshold is to cut the bottom 75% of test files, which results in 13 min of tests getting cut.
test_ops, functorch/test_ops, test_decomp, and other really long-running test files are not getting cut, so the top 25% still takes a really long time (90+ min).

The original plan was to test on ROCm, but I'm worried about queuing given that cutting 75% of test files only cuts off 13 min. Crossref is rarely referenced by others, and people keep talking about getting rid of it, so it's a good alternative.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119426
Approved by: https://github.com/huydhn
This commit is contained in:
Catherine Lee
2024-02-29 18:53:43 +00:00
committed by PyTorch MergeBot
parent 1458f1de66
commit 0290fe65bd
2 changed files with 23 additions and 8 deletions

View File

@ -32,6 +32,7 @@ from torch.testing._internal.common_utils import (
set_cwd,
shell,
TEST_WITH_ASAN,
TEST_WITH_CROSSREF,
TEST_WITH_ROCM,
TEST_WITH_SLOW_GRADCHECK,
)
@ -1164,6 +1165,7 @@ def parse_args():
action="store_true",
help="Enables removing tests based on TD",
default=IS_CI
and TEST_WITH_CROSSREF
and os.getenv("BRANCH", "") != "main"
and not strtobool(os.environ.get("NO_TD", "False")),
)
@ -1462,7 +1464,7 @@ def do_sharding(
test_file_times: Dict[str, float],
test_class_times: Dict[str, Dict[str, float]],
sort_by_time: bool = True,
) -> List[ShardedTest]:
) -> Tuple[float, List[ShardedTest]]:
which_shard, num_shards = get_sharding_opts(options)
# Do sharding
@ -1474,10 +1476,7 @@ def do_sharding(
must_serial=must_serial,
sort_by_time=sort_by_time,
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
return selected_tests
return shards[which_shard - 1]
class TestFailure(NamedTuple):
@ -1666,7 +1665,7 @@ def main():
):
self.name = name
self.failures = []
self.sharded_tests = do_sharding(
self.time, self.sharded_tests = do_sharding(
options,
raw_tests,
test_file_times_dict,
@ -1675,7 +1674,7 @@ def main():
)
def __str__(self):
s = f"Name: {self.name}\n"
s = f"Name: {self.name} (est. time: {round(self.time / 60, 2)}min)\n"
serial = [test for test in self.sharded_tests if must_serial(test)]
parallel = [test for test in self.sharded_tests if not must_serial(test)]
s += f" Serial tests ({len(serial)}):\n"
@ -1684,9 +1683,19 @@ def main():
s += "".join(f" {test}\n" for test in parallel)
return s.strip()
test_batch = TestBatch("all_tests", test_prioritizations.get_all_tests(), False)
percent_to_run = 25 if options.enable_td else 100
print_to_stderr(
f"Running {percent_to_run}% of tests based on TD"
if options.enable_td
else "Running all tests"
)
include, exclude = test_prioritizations.get_top_per_tests(percent_to_run)
test_batch = TestBatch("tests to run", include, False)
test_batch_exclude = TestBatch("excluded", exclude, True)
print_to_stderr(test_batch)
print_to_stderr(test_batch_exclude)
if options.dry_run:
return

View File

@ -112,6 +112,12 @@ class TestPrioritizations:
"""Returns all tests in the TestPrioritizations"""
return [x[1] for x in self._traverse_scores()]
def get_top_per_tests(self, n: int) -> Tuple[List[TestRun], List[TestRun]]:
    """Divide the tests into two lists based on the top n% of scores.

    Tests are taken in the order yielded by ``self._traverse_scores()``.
    The number of tests placed in the first list is ``n`` percent of the
    total, rounded down.

    Args:
        n: Percentage of tests to put in the first list. Values below 0
            are treated as 0 and values above 100 are treated as 100.

    Returns:
        A ``(top, rest)`` tuple: the first list is the top n% of tests and
        the second list is every remaining test, in the same order.
    """
    tests = [x[1] for x in self._traverse_scores()]
    # Clamp the percentage: a negative value would otherwise become a
    # negative slice index and silently select from the end of the list.
    percent = min(max(n, 0), 100)
    cutoff = percent * len(tests) // 100
    return tests[:cutoff], tests[cutoff:]
def get_info_str(self) -> str:
info = ""