Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Reordering tests experiment (#106347)
Companion to https://github.com/pytorch/test-infra/pull/4424. This uses the file ratings generated by that test-infra PR to reorder tests: for each test file, sum the ratings contributed by the files changed in the PR, then order the test files by that sum. A lot of tests will probably end up "prioritized", since anything with a rating > 0 currently qualifies.

Sharding is done twice: once over the prioritized tests and once over the general (non-prioritized) tests. Prioritized tests have a meaningful order, so they are sharded according to that order, while general tests have no order and are sharded by test time, which should produce more balanced shards.

I'll change the metric name before merging; I want to keep my testing data quarantined from the actual results.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106347
Approved by: https://github.com/ZainRizvi
Committed by: PyTorch MergeBot
Parent: 884c03d240
Commit: f16be5e0d4
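For orientation, the reordering and two-pass sharding described in the commit message reduce to roughly the sketch below. This is a minimal standalone illustration, not the code added by this PR: the rating table, changed-file list, test names, and helper names (split_by_rating, shard_in_order, shard_by_time) are made-up placeholders. The real implementation lives in tools/testing/test_selections.py (_get_file_rating_tests, get_reordered_tests) and test/run_test.py (do_sharding), shown in the diff further down.

from collections import defaultdict
from typing import Dict, List, Tuple


def split_by_rating(
    tests: List[str],
    changed_files: List[str],
    file_ratings: Dict[str, Dict[str, float]],
) -> Tuple[List[str], List[str]]:
    # For each test file, sum the ratings contributed by every file changed in the PR.
    score: Dict[str, float] = defaultdict(float)
    for changed in changed_files:
        for test_file, rating in file_ratings.get(changed, {}).items():
            score[test_file] += rating
    # Anything with a positive sum is "prioritized" and keeps a rating-based order;
    # everything else stays in the "general" pool.
    prioritized = sorted(
        (t for t in tests if score[t] > 0), key=lambda t: score[t], reverse=True
    )
    general = [t for t in tests if score[t] <= 0]
    return prioritized, general


def shard_in_order(tests: List[str], num_shards: int, which: int) -> List[str]:
    # Prioritized tests have a meaningful order, so deal them out round-robin
    # instead of rebalancing them by duration.
    return tests[which - 1 :: num_shards]


def shard_by_time(
    tests: List[str], times: Dict[str, float], num_shards: int, which: int
) -> List[str]:
    # General tests have no order; greedily put the longest remaining test on the
    # currently lightest shard to keep shard durations balanced.
    shards: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]
    for t in sorted(tests, key=lambda t: times.get(t, 0.0), reverse=True):
        idx = min(range(num_shards), key=lambda i: shards[i][0])
        total, members = shards[idx]
        shards[idx] = (total + times.get(t, 0.0), members + [t])
    return shards[which - 1][1]


if __name__ == "__main__":
    ratings = {"torch/foo.py": {"test_foo": 2.0, "test_bar": 0.5}}  # placeholder data
    pri, gen = split_by_rating(
        ["test_foo", "test_bar", "test_baz"], ["torch/foo.py"], ratings
    )
    print(pri, gen)  # ['test_foo', 'test_bar'] ['test_baz']
    print(shard_in_order(pri, num_shards=2, which=1))  # ['test_foo']
    print(shard_by_time(gen, {"test_baz": 30.0}, num_shards=2, which=1))  # ['test_baz']

In the PR itself the same split feeds do_sharding twice: prioritized tests with sort_by_time=False so their order survives sharding, and general tests with the default time-based balancing.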
@@ -128,6 +128,7 @@ python -c "import os, glob; os.system('python -mpip install --no-index --no-deps
 :: export test times so that potential sharded tests that'll branch off this build will use consistent data
 python tools/stats/export_test_times.py
 copy /Y ".pytorch-test-times.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
+copy /Y ".pytorch-test-file-ratings.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
 
 :: Also save build/.ninja_log as an artifact
 copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\"
@@ -2,6 +2,7 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
 
 echo Copying over test times file
 copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
+copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"
 
 pushd test
 
@@ -23,6 +23,7 @@ if "%SHARD_NUMBER%" == "1" (
 
 echo Copying over test times file
 copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
+copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"
 
 echo Run nn tests
 python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
.circleci/config.yml (generated): 2 changed lines
@@ -652,7 +652,7 @@ jobs:
     - run:
         name: Archive artifacts into zip
         command: |
-          zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+          zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
          cp artifacts.zip /Users/distiller/workspace
 
     - persist_to_workspace:
@@ -177,7 +177,7 @@
     - run:
         name: Archive artifacts into zip
         command: |
-          zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+          zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
          cp artifacts.zip /Users/distiller/workspace
 
     - persist_to_workspace:
.github/workflows/_linux-build.yml (vendored): 2 changed lines
@@ -170,7 +170,7 @@ jobs:
      - name: Archive artifacts into zip
        if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
        run: |
-          zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
+          zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json .pytorch-test-file-ratings.json
 
      - name: Store PyTorch Build Artifacts on S3
        uses: seemethere/upload-artifact-s3@v5
.github/workflows/_mac-build.yml (vendored): 2 changed lines
@@ -182,7 +182,7 @@ jobs:
      - name: Archive artifacts into zip
        if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
        run: |
-          zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+          zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
 
      - name: Store PyTorch Build Artifacts on GHA
        uses: actions/upload-artifact@v3
.gitignore (vendored): 1 changed line
@@ -19,6 +19,7 @@ coverage.xml
 **/.pytorch-disabled-tests.json
 **/.pytorch-slow-tests.json
 **/.pytorch-test-times.json
+**/.pytorch-test-file-ratings.json
 */*.pyc
 */*.so*
 */**/__pycache__
test/run_test.py: 185 changed lines
@@ -11,9 +11,10 @@ import signal
 import subprocess
 import sys
 import tempfile
+import time
 from datetime import datetime
 from distutils.version import LooseVersion
-from typing import Any, cast, Dict, List, Optional, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Union
 
 import pkg_resources
 
@@ -40,11 +41,11 @@ try:
     # using tools/ to optimize test run.
     sys.path.insert(0, str(REPO_ROOT))
     from tools.stats.export_test_times import TEST_TIMES_FILE
+    from tools.stats.upload_stats_lib import emit_metric
     from tools.testing.test_selections import (
         calculate_shards,
         get_reordered_tests,
         get_test_case_configs,
-        log_time_savings,
         NUM_PROCS,
         ShardedTest,
         THRESHOLD,
@@ -1278,7 +1279,9 @@ def exclude_tests(
     return selected_tests
 
 
-def must_serial(file: str) -> bool:
+def must_serial(file: Union[str, ShardedTest]) -> bool:
+    if isinstance(file, ShardedTest):
+        file = file.name
     return (
         os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
         or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "")
@@ -1408,20 +1411,10 @@ def get_selected_tests(options) -> List[ShardedTest]:
         )
 
     selected_tests = [parse_test_module(x) for x in selected_tests]
+    return selected_tests
 
-    # sharding
-    which_shard, num_shards = 1, 1
-    if options.shard:
-        assert len(options.shard) == 2, "Unexpected shard format"
-        assert min(options.shard) > 0, "Shards must be positive numbers"
-        which_shard, num_shards = options.shard
-        assert (
-            which_shard <= num_shards
-        ), "Selected shard must be less than or equal to total number of shards"
-        assert num_shards <= len(
-            selected_tests
-        ), f"Number of shards must be less than {len(selected_tests)}"
 
+def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
     # Download previous test times to make sharding decisions
     path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
     if os.path.exists(path):
@@ -1434,14 +1427,35 @@ def get_selected_tests(options) -> List[ShardedTest]:
         print(
             "::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
         )
+        return {}
     else:
         print("Found test time stats from artifacts")
+        return test_file_times[test_config]
+
+
+def do_sharding(
+    options,
+    selected_tests: List[str],
+    test_file_times: Dict[str, float],
+    sort_by_time: bool = True,
+) -> List[ShardedTest]:
+    which_shard, num_shards = 1, 1
+    if options.shard:
+        assert len(options.shard) == 2, "Unexpected shard format"
+        assert min(options.shard) > 0, "Shards must be positive numbers"
+        which_shard, num_shards = options.shard
+        assert (
+            which_shard <= num_shards
+        ), "Selected shard must be less than or equal to total number of shards"
 
     if HAVE_TEST_SELECTION_TOOLS:
         # Do sharding
-        test_file_times_config = test_file_times.get(test_config, {})
         shards = calculate_shards(
-            num_shards, selected_tests, test_file_times_config, must_serial=must_serial
+            num_shards,
+            selected_tests,
+            test_file_times,
+            must_serial=must_serial,
+            sort_by_time=sort_by_time,
         )
         _, tests_from_shard = shards[which_shard - 1]
         selected_tests = tests_from_shard
@@ -1449,9 +1463,14 @@ def get_selected_tests(options) -> List[ShardedTest]:
     return selected_tests
 
 
+class TestFailure(NamedTuple):
+    test: str
+    message: str
+
+
 def run_test_module(
     test: Union[ShardedTest, str], test_directory: str, options
-) -> Optional[str]:
+) -> Optional[TestFailure]:
     maybe_set_hip_visible_devies()
 
     # Printing the date here can help diagnose which tests are slow
@@ -1472,39 +1491,24 @@ def run_test_module(
         # return code -N, where N is the signal number.
         signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
         message += f" Received signal: {signal_name}"
-    return message
+    return TestFailure(test, message)
 
 
 def run_tests(
-    selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
+    selected_tests: List[ShardedTest],
+    test_directory: str,
+    options,
+    failures: List[TestFailure],
 ) -> None:
-    failure_messages = []
-
     if len(selected_tests) == 0:
-        print_to_stderr(f"No tests in group `{group_name}`")
-        return failure_messages
+        return
 
     # parallel = in parallel with other files
     # serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
-    selected_tests_parallel = [
-        x
-        for x in selected_tests
-        if not must_serial(x.name if isinstance(x, ShardedTest) else x)
-    ]
+    selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
     selected_tests_serial = [
         x for x in selected_tests if x not in selected_tests_parallel
     ]
-    print(f"TEST GROUP: {group_name}")
-    print_to_stderr(
-        "parallel (file granularity) tests :\n {}".format(
-            "\n".join(str(x) for x in selected_tests_parallel)
-        )
-    )
-    print_to_stderr(
-        "serial (file granularity) tests:\n {}".format(
-            "\n ".join(str(x) for x in selected_tests_serial)
-        )
-    )
 
     # See Note [ROCm parallel CI testing]
     pool = get_context("spawn").Pool(
@@ -1523,15 +1527,15 @@ def run_tests(
         # Take the conftest file from the test directory
         shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file)
 
-    def handle_error_messages(err_message):
-        if err_message is None:
+    def handle_error_messages(failure: Optional[TestFailure]):
+        if failure is None:
             return False
-        failure_messages.append(err_message)
-        print_to_stderr(err_message)
+        failures.append(failure)
+        print_to_stderr(failure.message)
         return True
 
-    def parallel_test_completion_callback(err_message):
-        test_failed = handle_error_messages(err_message)
+    def parallel_test_completion_callback(failure):
+        test_failed = handle_error_messages(failure)
         if (
             test_failed
             and not options.continue_through_error
@@ -1557,10 +1561,10 @@ def run_tests(
         if (
             not options.continue_through_error
            and not RERUN_DISABLED_TESTS
-            and len(failure_messages) != 0
+            and len(failures) != 0
         ):
             raise RuntimeError(
-                "\n".join(failure_messages)
+                "\n".join(x.message for x in failures)
                 + "\n\nTip: You can keep running tests even on failure by "
                 "passing --keep-going to run_test.py.\n"
                 "If running on CI, add the 'keep-going' label to "
@@ -1571,20 +1575,20 @@ def run_tests(
             options_clone = copy.deepcopy(options)
             if can_run_in_pytest(test):
                 options_clone.pytest = True
-            err_message = run_test_module(test, test_directory, options_clone)
-            test_failed = handle_error_messages(err_message)
+            failure = run_test_module(test, test_directory, options_clone)
+            test_failed = handle_error_messages(failure)
             if (
                 test_failed
                 and not options.continue_through_error
                 and not RERUN_DISABLED_TESTS
             ):
-                raise RuntimeError(err_message)
+                raise RuntimeError(failure.message)
 
     finally:
         pool.terminate()
         pool.join()
 
-    return failure_messages
+    return
 
 
 def check_pip_packages() -> None:
@@ -1611,30 +1615,47 @@ def main():
     test_directory = str(REPO_ROOT / "test")
     selected_tests = get_selected_tests(options)
 
-    if options.verbose:
-        print_to_stderr(
-            "Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
-        )
-
-    if options.dry_run:
-        return
-
     if options.coverage and not PYTORCH_COLLECT_COVERAGE:
         shell(["coverage", "erase"])
 
     prioritized_tests = []
-    remaining_tests = selected_tests
+    general_tests = selected_tests
     if IS_CI and HAVE_TEST_SELECTION_TOOLS:
-        (prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
-        log_time_savings(
-            selected_tests,
-            prioritized_tests,
-            is_serial_test_fn=must_serial,
-            num_procs=NUM_PROCS,
-        )
-
         # downloading test cases configuration to local environment
         get_test_case_configs(dirpath=test_directory)
+        (prioritized_tests, general_tests) = get_reordered_tests(general_tests)
+
+    metrics_dict = {
+        "prioritized_tests": prioritized_tests,
+        "general_tests": general_tests,
+        "cpp": options.cpp,
+    }
+
+    test_times_dict = download_test_times(TEST_TIMES_FILE)
+    prioritized_tests = do_sharding(
+        options, prioritized_tests, test_times_dict, sort_by_time=False
+    )
+    general_tests = do_sharding(options, general_tests, test_times_dict)
+
+    if options.verbose:
+
+        def print_tests(category, tests):
+            tests_str = "\n ".join(str(x) for x in tests)
+            print_to_stderr(f"{category} tests:\n {tests_str}")
+
+        print_tests(
+            "Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)]
+        )
+        print_tests(
+            "Prioritized serial", [x for x in prioritized_tests if must_serial(x)]
+        )
+        print_tests(
+            "General parallel", [x for x in general_tests if not must_serial(x)]
+        )
+        print_tests("General serial", [x for x in general_tests if must_serial(x)])
+
+    if options.dry_run:
+        return
 
     if options.dynamo:
         os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
@@ -1646,17 +1667,17 @@ def main():
 
     os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
 
-    failure_messages = []
-
+    prioritized_failures: List[TestFailure] = []
+    general_failures: List[TestFailure] = []
+    start_time = time.time()
     # First run the prioritized tests, then the remaining tests.
     try:
-        failure_messages = run_tests(
-            prioritized_tests, test_directory, options, "Prioritized tests"
-        )
-
-        failure_messages += run_tests(
-            remaining_tests, test_directory, options, "General tests"
-        )
+        run_tests(prioritized_tests, test_directory, options, prioritized_failures)
+        metrics_dict["prioritized_failures"] = [x.test for x in prioritized_failures]
+        metrics_dict["general_start_time"] = time.time() - start_time
+        run_tests(general_tests, test_directory, options, general_failures)
+        metrics_dict["general_end_time"] = time.time() - start_time
+        metrics_dict["all_failures"] = [x.test for x in general_failures]
 
     finally:
         if options.coverage:
@@ -1671,8 +1692,12 @@ def main():
            if not PYTORCH_COLLECT_COVERAGE:
                cov.html_report()
 
-    if len(failure_messages) != 0:
-        for err in failure_messages:
+    if IS_CI and HAVE_TEST_SELECTION_TOOLS:
+        emit_metric("td_experiment_1", metrics_dict)
+
+    all_failures = prioritized_failures + general_failures
+    if len(all_failures) != 0:
+        for _, err in all_failures:
             print_to_stderr(err)
 
     # A disabled test is expected to fail, so there is no need to report a failure here
@@ -3,14 +3,16 @@ import sys
 
 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
 sys.path.append(str(REPO_ROOT))
-from tools.stats.import_test_stats import get_test_times
+from tools.stats.import_test_stats import get_test_file_ratings, get_test_times
 
 TEST_TIMES_FILE = ".pytorch-test-times.json"
+TEST_FILE_RATINGS_FILE = ".pytorch-test-file-ratings.json"
 
 
 def main() -> None:
     print(f"Exporting test times from test-infra to {TEST_TIMES_FILE}")
     get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE)
+    get_test_file_ratings(str(REPO_ROOT), filename=TEST_FILE_RATINGS_FILE)
 
 
 if __name__ == "__main__":
@@ -20,6 +20,7 @@ IGNORE_DISABLED_ISSUES: List[str] = get_disabled_issues()
 SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
 DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
 
 
 FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
 
 
@@ -116,3 +117,12 @@ def get_disabled_tests(
     except Exception:
         print("Couldn't download test skip set, leaving all tests enabled...")
         return {}
+
+
+def get_test_file_ratings(dirpath: str, filename: str) -> Optional[Dict[str, Any]]:
+    url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/file_test_rating.json"
+    try:
+        return fetch_and_cache(dirpath, filename, url, lambda x: x)
+    except Exception:
+        print("Couldn't download test file ratings file, not reordering...")
+        return {}
@@ -263,7 +263,7 @@ class EnvVarMetric:
         value = os.environ.get(self.env_var)
         if value is None and self.required:
             raise ValueError(
-                f"Missing {self.name}. Please set the {self.env_var}"
+                f"Missing {self.name}. Please set the {self.env_var} "
                 "environment variable to pass in this value."
             )
         if self.type_conversion_fn:
@@ -394,28 +394,24 @@ class TestParsePrevTests(unittest.TestCase):
         "tools.testing.test_selections._get_modified_tests",
         return_value={"test2", "test4"},
     )
+    @mock.patch(
+        "tools.testing.test_selections._get_file_rating_tests", return_value=["test1"]
+    )
     def test_get_reordered_tests(
-        self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any
+        self,
+        mock_get_prev_failing_tests: Any,
+        mock_get_modified_tests: Any,
+        mock_get_file_rating_tests: Any,
     ) -> None:
-        tests = [
-            ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
-            ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
-            ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
-            ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
-            ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
-        ]
+        tests = ["test1", "test2", "test3", "test4", "test5"]
 
-        expected_prioritized_tests = {"test4", "test2"}
-        expected_remaining_tests = {"test1", "test3", "test5"}
+        expected_prioritized_tests = ["test4", "test2", "test1"]
+        expected_remaining_tests = {"test3", "test5"}
 
         prioritized_tests, remaining_tests = get_reordered_tests(tests)
 
-        # Just want to check the names of the tests
-        prioritized_tests_name = {test.name for test in prioritized_tests}
-        remaining_tests_name = {test.name for test in remaining_tests}
-
-        self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
-        self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
+        self.assertListEqual(expected_prioritized_tests, prioritized_tests)
+        self.assertSetEqual(expected_remaining_tests, set(remaining_tests))
 
     def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
         tests = [
@@ -3,16 +3,20 @@ import json
 import math
 import os
 import subprocess
+from collections import defaultdict
 from pathlib import Path
 
-from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
+from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple
 from warnings import warn
 
 from tools.shared.logging_utils import duration_to_str, pluralize
+from tools.stats.export_test_times import TEST_FILE_RATINGS_FILE
 
 from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
 from tools.stats.upload_stats_lib import emit_metric
 
 REPO_ROOT = Path(__file__).resolve().parent.parent.parent
 
 IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
 
 # NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
@@ -81,8 +85,8 @@ def get_with_pytest_shard(
 ) -> List[ShardedTest]:
     sharded_tests: List[ShardedTest] = []
     for test in tests:
-        duration = test_file_times[test]
-        if duration > THRESHOLD:
+        duration = test_file_times.get(test, None)
+        if duration and duration > THRESHOLD:
             num_shards = math.ceil(duration / THRESHOLD)
             for i in range(num_shards):
                 sharded_tests.append(
@@ -98,20 +102,24 @@ def calculate_shards(
     tests: List[str],
     test_file_times: Dict[str, float],
     must_serial: Optional[Callable[[str], bool]] = None,
+    sort_by_time: bool = True,
 ) -> List[Tuple[float, List[ShardedTest]]]:
     must_serial = must_serial or (lambda x: True)
 
-    known_tests = [x for x in tests if x in test_file_times]
-    unknown_tests: List[str] = [x for x in tests if x not in known_tests]
+    known_tests = tests
+    unknown_tests = []
 
-    sorted_tests = sorted(
-        get_with_pytest_shard(known_tests, test_file_times),
-        key=lambda j: j.get_time(),
-        reverse=True,
-    )
+    if sort_by_time:
+        known_tests = [x for x in tests if x in test_file_times]
+        unknown_tests = [x for x in tests if x not in known_tests]
+
+    known_tests = get_with_pytest_shard(known_tests, test_file_times)
+
+    if sort_by_time:
+        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
 
     sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
-    for test in sorted_tests:
+    for test in known_tests:
         if must_serial(test.name):
             min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
             min_sharded_job.serial.append(test)
@@ -127,7 +135,7 @@
     return [job.convert_to_tuple() for job in sharded_jobs]
 
 
-def _query_changed_test_files() -> List[str]:
+def _query_changed_files() -> List[str]:
     default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
     merge_base = (
         subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
@@ -186,7 +194,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st
 
 def _get_modified_tests() -> Set[str]:
     try:
-        changed_files = _query_changed_test_files()
+        changed_files = _query_changed_files()
     except Exception as e:
         warn(f"Can't query changed test files due to {e}")
         # If unable to get changed files from git, quit without doing any sorting
@ -271,76 +279,81 @@ def log_time_savings(
|
||||
return max_time_savings_sec
|
||||
|
||||
|
||||
def _get_file_rating_tests() -> List[str]:
|
||||
path = REPO_ROOT / TEST_FILE_RATINGS_FILE
|
||||
if not os.path.exists(path):
|
||||
print(f"could not find path {path}")
|
||||
return []
|
||||
with open(path) as f:
|
||||
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
|
||||
try:
|
||||
changed_files = _query_changed_files()
|
||||
except Exception as e:
|
||||
warn(f"Can't query changed test files due to {e}")
|
||||
return []
|
||||
ratings: Dict[str, float] = defaultdict(float)
|
||||
for file in changed_files:
|
||||
for test_file, score in test_file_ratings.get(file, {}).items():
|
||||
ratings[test_file] += score
|
||||
prioritize = sorted(ratings, key=lambda x: ratings[x])
|
||||
return prioritize
|
||||
|
||||
|
||||
def get_reordered_tests(
|
||||
tests: List[ShardedTest],
|
||||
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
|
||||
tests: List[str],
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Get the reordered test filename list based on github PR history or git changed file.
|
||||
We prioritize running test files that were changed.
|
||||
"""
|
||||
prioritized_tests: List[str] = []
|
||||
|
||||
def print_tests(tests: Set[str], test_group_description: str) -> None:
|
||||
if not tests:
|
||||
def add_tests(tests_to_add: List[str], test_group_description: str) -> None:
|
||||
if not tests_to_add:
|
||||
return
|
||||
|
||||
print(f"{test_group_description}:")
|
||||
for test in tests:
|
||||
print(f" {test}")
|
||||
for test in tests_to_add:
|
||||
if test in tests:
|
||||
print(f" {test}")
|
||||
if test not in prioritized_tests:
|
||||
prioritized_tests.append(test)
|
||||
|
||||
prioritized_tests: Set[str] = set()
|
||||
|
||||
pri_test = _get_previously_failing_tests()
|
||||
print_tests(
|
||||
pri_test, "If run, these tests will prioritized because they previously failed"
|
||||
add_tests(
|
||||
sorted(_get_previously_failing_tests()),
|
||||
"If run, these tests will prioritized because they previously failed",
|
||||
)
|
||||
prioritized_tests |= pri_test
|
||||
|
||||
pri_test |= _get_modified_tests()
|
||||
print_tests(
|
||||
pri_test, "If run, these tests will be prioritized because they were modified"
|
||||
add_tests(
|
||||
sorted(_get_modified_tests()),
|
||||
"If run, these tests will be prioritized because they were modified",
|
||||
)
|
||||
prioritized_tests |= pri_test
|
||||
|
||||
bring_to_front = []
|
||||
the_rest = []
|
||||
add_tests(
|
||||
_get_file_rating_tests(),
|
||||
"If run, these tests will be preioritized for an experiment in TD",
|
||||
)
|
||||
|
||||
for test in tests:
|
||||
if test.name in prioritized_tests:
|
||||
bring_to_front.append(test)
|
||||
else:
|
||||
the_rest.append(test)
|
||||
prioritized_tests = [x for x in prioritized_tests if x in tests]
|
||||
the_rest = [x for x in tests if x not in prioritized_tests]
|
||||
|
||||
if len(tests) != len(bring_to_front) + len(the_rest):
|
||||
print(
|
||||
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
|
||||
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
|
||||
)
|
||||
return ([], tests)
|
||||
|
||||
prioritized_test_names = []
|
||||
remaining_test_names = []
|
||||
if bring_to_front:
|
||||
if prioritized_tests:
|
||||
test_cnt_str = pluralize(len(tests), "test")
|
||||
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
|
||||
|
||||
prioritized_test_names = [t.name for t in bring_to_front]
|
||||
print(f"Prioritized: {prioritized_test_names}")
|
||||
remaining_test_names = [t.name for t in the_rest]
|
||||
print(f"The Rest: {remaining_test_names}")
|
||||
else:
|
||||
print("Didn't find any tests to prioritize")
|
||||
print(
|
||||
f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
|
||||
)
|
||||
|
||||
emit_metric(
|
||||
"test_reordering_prioritized_tests",
|
||||
{
|
||||
"prioritized_test_cnt": len(bring_to_front),
|
||||
"prioritized_test_cnt": len(prioritized_tests),
|
||||
"total_test_cnt": len(tests),
|
||||
"prioritized_tests": prioritized_test_names,
|
||||
"remaining_tests": remaining_test_names,
|
||||
"prioritized_tests": prioritized_tests,
|
||||
"remaining_tests": the_rest,
|
||||
},
|
||||
)
|
||||
|
||||
return (bring_to_front, the_rest)
|
||||
return (prioritized_tests, the_rest)
|
||||
|
||||
|
||||
def get_test_case_configs(dirpath: str) -> None:
|
||||
|