Always run prioritized tests first, even if they're expected to run serially (#100748)

Today, we prioritize running test files that were edited in the user's PR, with the idea being to run them before we run any other test.

However, if a modified test is supposed to run serially, we still end up running it only after all the parallelized tests have finished running.

This PR fixes that by _always_ running the prioritized tests before the regular tests, regardless of whether a test is supposed to run serially or in parallel.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100748
Approved by: https://github.com/huydhn
This commit is contained in:
Zain Rizvi
2023-05-08 20:20:37 +00:00
committed by PyTorch MergeBot
parent c4bbeb5b8a
commit 95f191a248
2 changed files with 64 additions and 34 deletions

View File

@ -1408,44 +1408,25 @@ def run_test_module(test: ShardedTest, test_directory: str, options) -> Optional
return message
def main():
options = parse_args()
test_directory = str(REPO_ROOT / "test")
selected_tests = get_selected_tests(options)
if options.verbose:
print_to_stderr(
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
)
if options.dry_run:
return
if options.coverage and not PYTORCH_COLLECT_COVERAGE:
shell(["coverage", "erase"])
if IS_CI:
selected_tests = get_reordered_tests(selected_tests)
# downloading test cases configuration to local environment
get_test_case_configs(dirpath=test_directory)
if options.dynamo:
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
elif options.inductor:
os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"
def run_tests(
selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
) -> None:
failure_messages = []
if len(selected_tests) == 0:
print_to_stderr(f"No tests in group `{group_name}`")
return failure_messages
# parallel = in parallel with other files
# serial = this file on its own. The file might still be run in parallel with itself (e.g. test_ops)
selected_tests_parallel = [x for x in selected_tests if not must_serial(x.name)]
selected_tests_serial = [
x for x in selected_tests if x not in selected_tests_parallel
]
print(f"TEST GROUP: {group_name}")
print_to_stderr(
"parallel (file granularity) tests:\n {}".format(
"\n ".join(str(x) for x in selected_tests_parallel)
"parallel (file granularity) tests :\n {}".format(
"\n".join(str(x) for x in selected_tests_parallel)
)
)
print_to_stderr(
@ -1458,7 +1439,6 @@ def main():
pool = get_context("spawn").Pool(
NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1
)
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
def handle_error_messages(err_message):
if err_message is None:
@ -1504,10 +1484,58 @@ def main():
test_failed = handle_error_messages(err_message)
if test_failed and not options.continue_through_error:
raise RuntimeError(err_message)
finally:
pool.terminate()
pool.join()
return failure_messages
def main():
options = parse_args()
test_directory = str(REPO_ROOT / "test")
selected_tests = get_selected_tests(options)
if options.verbose:
print_to_stderr(
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
)
if options.dry_run:
return
if options.coverage and not PYTORCH_COLLECT_COVERAGE:
shell(["coverage", "erase"])
prioritized_tests = []
remaining_tests = selected_tests
if IS_CI:
(prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
# downloading test cases configuration to local environment
get_test_case_configs(dirpath=test_directory)
if options.dynamo:
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
elif options.inductor:
os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
failure_messages = []
# First run the prioritized tests, then the remaining tests.
try:
failure_messages = run_tests(
prioritized_tests, test_directory, options, "Prioritized tests"
)
failure_messages += run_tests(
remaining_tests, test_directory, options, "General tests"
)
finally:
if options.coverage:
from coverage import Coverage

View File

@ -127,7 +127,9 @@ def _query_changed_test_files() -> List[str]:
return lines
def get_reordered_tests(tests: List[ShardedTest]) -> List[ShardedTest]:
def get_reordered_tests(
tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
"""
Get the reordered test filename list based on github PR history or git changed file.
We prioritize running test files that were changed.
@ -138,7 +140,7 @@ def get_reordered_tests(tests: List[ShardedTest]) -> List[ShardedTest]:
changed_files = _query_changed_test_files()
except Exception:
# If unable to get changed files from git, quit without doing any sorting
return tests
return ([], tests)
prefix = f"test{os.path.sep}"
prioritized_tests = [
@ -161,13 +163,13 @@ def get_reordered_tests(tests: List[ShardedTest]) -> List[ShardedTest]:
f"reordering tests for PR:\n"
f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
)
return bring_to_front + the_rest
return (bring_to_front, the_rest)
else:
print(
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
)
return tests
return ([], tests)
def get_test_case_configs(dirpath: str) -> None: