mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Always run prioritized tests first, even if they're expected to run serially (#100748)
Today, we prioritize running test files that were edited in the user's PR, with the idea being to run them before we run any other test. Except, if the modified test is supposed to run serially, then we still end up running it after all the parallelized tests have finished running. This PR fixes that to _always_ run the prioritized tests before the regular tests, regardless of if the test is supposed to run serially or in parallel Pull Request resolved: https://github.com/pytorch/pytorch/pull/100748 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
c4bbeb5b8a
commit
95f191a248
@ -1408,44 +1408,25 @@ def run_test_module(test: ShardedTest, test_directory: str, options) -> Optional
|
||||
return message
|
||||
|
||||
|
||||
def main():
|
||||
options = parse_args()
|
||||
|
||||
test_directory = str(REPO_ROOT / "test")
|
||||
selected_tests = get_selected_tests(options)
|
||||
|
||||
if options.verbose:
|
||||
print_to_stderr(
|
||||
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
|
||||
)
|
||||
|
||||
if options.dry_run:
|
||||
return
|
||||
|
||||
if options.coverage and not PYTORCH_COLLECT_COVERAGE:
|
||||
shell(["coverage", "erase"])
|
||||
|
||||
if IS_CI:
|
||||
selected_tests = get_reordered_tests(selected_tests)
|
||||
# downloading test cases configuration to local environment
|
||||
get_test_case_configs(dirpath=test_directory)
|
||||
|
||||
if options.dynamo:
|
||||
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
|
||||
elif options.inductor:
|
||||
os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"
|
||||
|
||||
def run_tests(
|
||||
selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
|
||||
) -> None:
|
||||
failure_messages = []
|
||||
|
||||
if len(selected_tests) == 0:
|
||||
print_to_stderr(f"No tests in group `{group_name}`")
|
||||
return failure_messages
|
||||
|
||||
# parallel = in parallel with other files
|
||||
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
|
||||
selected_tests_parallel = [x for x in selected_tests if not must_serial(x.name)]
|
||||
selected_tests_serial = [
|
||||
x for x in selected_tests if x not in selected_tests_parallel
|
||||
]
|
||||
print(f"TEST GROUP: {group_name}")
|
||||
print_to_stderr(
|
||||
"parallel (file granularity) tests:\n {}".format(
|
||||
"\n ".join(str(x) for x in selected_tests_parallel)
|
||||
"parallel (file granularity) tests :\n {}".format(
|
||||
"\n".join(str(x) for x in selected_tests_parallel)
|
||||
)
|
||||
)
|
||||
print_to_stderr(
|
||||
@ -1458,7 +1439,6 @@ def main():
|
||||
pool = get_context("spawn").Pool(
|
||||
NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1
|
||||
)
|
||||
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
|
||||
|
||||
def handle_error_messages(err_message):
|
||||
if err_message is None:
|
||||
@ -1504,10 +1484,58 @@ def main():
|
||||
test_failed = handle_error_messages(err_message)
|
||||
if test_failed and not options.continue_through_error:
|
||||
raise RuntimeError(err_message)
|
||||
|
||||
finally:
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
|
||||
return failure_messages
|
||||
|
||||
|
||||
def main():
|
||||
options = parse_args()
|
||||
|
||||
test_directory = str(REPO_ROOT / "test")
|
||||
selected_tests = get_selected_tests(options)
|
||||
|
||||
if options.verbose:
|
||||
print_to_stderr(
|
||||
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
|
||||
)
|
||||
|
||||
if options.dry_run:
|
||||
return
|
||||
|
||||
if options.coverage and not PYTORCH_COLLECT_COVERAGE:
|
||||
shell(["coverage", "erase"])
|
||||
|
||||
prioritized_tests = []
|
||||
remaining_tests = selected_tests
|
||||
if IS_CI:
|
||||
(prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
|
||||
# downloading test cases configuration to local environment
|
||||
get_test_case_configs(dirpath=test_directory)
|
||||
|
||||
if options.dynamo:
|
||||
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
|
||||
elif options.inductor:
|
||||
os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"
|
||||
|
||||
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
|
||||
|
||||
failure_messages = []
|
||||
|
||||
# First run the prioritized tests, then the remaining tests.
|
||||
try:
|
||||
failure_messages = run_tests(
|
||||
prioritized_tests, test_directory, options, "Prioritized tests"
|
||||
)
|
||||
|
||||
failure_messages += run_tests(
|
||||
remaining_tests, test_directory, options, "General tests"
|
||||
)
|
||||
|
||||
finally:
|
||||
if options.coverage:
|
||||
from coverage import Coverage
|
||||
|
||||
|
@ -127,7 +127,9 @@ def _query_changed_test_files() -> List[str]:
|
||||
return lines
|
||||
|
||||
|
||||
def get_reordered_tests(tests: List[ShardedTest]) -> List[ShardedTest]:
|
||||
def get_reordered_tests(
|
||||
tests: List[ShardedTest],
|
||||
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
|
||||
"""
|
||||
Get the reordered test filename list based on github PR history or git changed file.
|
||||
We prioritize running test files that were changed.
|
||||
@ -138,7 +140,7 @@ def get_reordered_tests(tests: List[ShardedTest]) -> List[ShardedTest]:
|
||||
changed_files = _query_changed_test_files()
|
||||
except Exception:
|
||||
# If unable to get changed files from git, quit without doing any sorting
|
||||
return tests
|
||||
return ([], tests)
|
||||
|
||||
prefix = f"test{os.path.sep}"
|
||||
prioritized_tests = [
|
||||
@ -161,13 +163,13 @@ def get_reordered_tests(tests: List[ShardedTest]) -> List[ShardedTest]:
|
||||
f"reordering tests for PR:\n"
|
||||
f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
|
||||
)
|
||||
return bring_to_front + the_rest
|
||||
return (bring_to_front, the_rest)
|
||||
else:
|
||||
print(
|
||||
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
|
||||
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
|
||||
)
|
||||
return tests
|
||||
return ([], tests)
|
||||
|
||||
|
||||
def get_test_case_configs(dirpath: str) -> None:
|
||||
|
Reference in New Issue
Block a user