Revert "Reordering tests experiment (#106347)"

This reverts commit 7dfab082be9eaeeee95c7b0363e59c824c6a9009.

Reverted https://github.com/pytorch/pytorch/pull/106347 on behalf of https://github.com/clee2000 due to probably broke sharding ([comment](https://github.com/pytorch/pytorch/pull/106347#issuecomment-1675542738))
PyTorch MergeBot
2023-08-11 23:59:48 +00:00
parent c9cbcb2449
commit 9858edd99f
6 changed files with 155 additions and 210 deletions

.gitignore

@@ -19,7 +19,6 @@ coverage.xml
**/.pytorch-disabled-tests.json
**/.pytorch-slow-tests.json
**/.pytorch-test-times.json
**/.pytorch-test-file-ratings.json
*/*.pyc
*/*.so*
*/**/__pycache__

test/run_test.py

@@ -11,10 +11,9 @@ import signal
import subprocess
import sys
import tempfile
import time
from datetime import datetime
from distutils.version import LooseVersion
from typing import Any, cast, Dict, List, NamedTuple, Optional, Union
from typing import Any, cast, Dict, List, Optional, Union
import pkg_resources
@@ -41,11 +40,11 @@ try:
# using tools/ to optimize test run.
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.export_test_times import TEST_TIMES_FILE
from tools.stats.upload_stats_lib import emit_metric
from tools.testing.test_selections import (
calculate_shards,
get_reordered_tests,
get_test_case_configs,
log_time_savings,
NUM_PROCS,
ShardedTest,
THRESHOLD,
@@ -1279,9 +1278,7 @@ def exclude_tests(
return selected_tests
def must_serial(file: Union[str, ShardedTest]) -> bool:
if isinstance(file, ShardedTest):
file = file.name
def must_serial(file: str) -> bool:
return (
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "")
@@ -1411,10 +1408,20 @@ def get_selected_tests(options) -> List[ShardedTest]:
)
selected_tests = [parse_test_module(x) for x in selected_tests]
return selected_tests
# sharding
which_shard, num_shards = 1, 1
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
assert num_shards <= len(
selected_tests
), f"Number of shards must be less than {len(selected_tests)}"
def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
@@ -1427,35 +1434,14 @@ def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
return {}
else:
print("Found test time stats from artifacts")
return test_file_times[test_config]
def do_sharding(
options,
selected_tests: List[str],
test_file_times: Dict[str, float],
sort_by_time: bool = True,
) -> List[ShardedTest]:
which_shard, num_shards = 1, 1
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
if HAVE_TEST_SELECTION_TOOLS:
# Do sharding
test_file_times_config = test_file_times.get(test_config, {})
shards = calculate_shards(
num_shards,
selected_tests,
test_file_times,
must_serial=must_serial,
sort_by_time=sort_by_time,
num_shards, selected_tests, test_file_times_config, must_serial=must_serial
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
@@ -1463,14 +1449,9 @@ def do_sharding(
return selected_tests
class TestFailure(NamedTuple):
test: str
message: str
def run_test_module(
test: Union[ShardedTest, str], test_directory: str, options
) -> Optional[TestFailure]:
) -> Optional[str]:
maybe_set_hip_visible_devies()
# Printing the date here can help diagnose which tests are slow
@@ -1491,24 +1472,39 @@ def run_test_module(
# return code -N, where N is the signal number.
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
message += f" Received signal: {signal_name}"
return TestFailure(test, message)
return message
def run_tests(
selected_tests: List[ShardedTest],
test_directory: str,
options,
failures: List[TestFailure],
selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
) -> None:
failure_messages = []
if len(selected_tests) == 0:
return
print_to_stderr(f"No tests in group `{group_name}`")
return failure_messages
# parallel = in parallel with other files
# serial = this file on its own. The file might still be run in parallel with itself (e.g. test_ops)
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
selected_tests_parallel = [
x
for x in selected_tests
if not must_serial(x.name if isinstance(x, ShardedTest) else x)
]
selected_tests_serial = [
x for x in selected_tests if x not in selected_tests_parallel
]
print(f"TEST GROUP: {group_name}")
print_to_stderr(
"parallel (file granularity) tests :\n {}".format(
"\n".join(str(x) for x in selected_tests_parallel)
)
)
print_to_stderr(
"serial (file granularity) tests:\n {}".format(
"\n ".join(str(x) for x in selected_tests_serial)
)
)
# See Note [ROCm parallel CI testing]
pool = get_context("spawn").Pool(
@@ -1527,15 +1523,15 @@ def run_tests(
# Take the conftest file from the test directory
shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file)
def handle_error_messages(failure: Optional[TestFailure]):
if failure is None:
def handle_error_messages(err_message):
if err_message is None:
return False
failures.append(failure)
print_to_stderr(failure.message)
failure_messages.append(err_message)
print_to_stderr(err_message)
return True
def parallel_test_completion_callback(failure):
test_failed = handle_error_messages(failure)
def parallel_test_completion_callback(err_message):
test_failed = handle_error_messages(err_message)
if (
test_failed
and not options.continue_through_error
@@ -1561,10 +1557,10 @@ def run_tests(
if (
not options.continue_through_error
and not RERUN_DISABLED_TESTS
and len(failures) != 0
and len(failure_messages) != 0
):
raise RuntimeError(
"\n".join(x.message for x in failures)
"\n".join(failure_messages)
+ "\n\nTip: You can keep running tests even on failure by "
"passing --keep-going to run_test.py.\n"
"If running on CI, add the 'keep-going' label to "
@@ -1575,20 +1571,20 @@ def run_tests(
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
failure = run_test_module(test, test_directory, options_clone)
test_failed = handle_error_messages(failure)
err_message = run_test_module(test, test_directory, options_clone)
test_failed = handle_error_messages(err_message)
if (
test_failed
and not options.continue_through_error
and not RERUN_DISABLED_TESTS
):
raise RuntimeError(failure.message)
raise RuntimeError(err_message)
finally:
pool.terminate()
pool.join()
return
return failure_messages
def check_pip_packages() -> None:
@@ -1615,47 +1611,30 @@ def main():
test_directory = str(REPO_ROOT / "test")
selected_tests = get_selected_tests(options)
if options.verbose:
print_to_stderr(
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
)
if options.dry_run:
return
if options.coverage and not PYTORCH_COLLECT_COVERAGE:
shell(["coverage", "erase"])
prioritized_tests = []
general_tests = selected_tests
remaining_tests = selected_tests
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
(prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
log_time_savings(
selected_tests,
prioritized_tests,
is_serial_test_fn=must_serial,
num_procs=NUM_PROCS,
)
# downloading test cases configuration to local environment
get_test_case_configs(dirpath=test_directory)
(prioritized_tests, general_tests) = get_reordered_tests(general_tests)
metrics_dict = {
"prioritized_tests": prioritized_tests,
"general_tests": general_tests,
"cpp": options.cpp,
}
test_times_dict = download_test_times(TEST_TIMES_FILE)
prioritized_tests = do_sharding(
options, prioritized_tests, test_times_dict, sort_by_time=False
)
general_tests = do_sharding(options, general_tests, test_times_dict)
if options.verbose:
def print_tests(category, tests):
tests_str = "\n ".join(str(x) for x in tests)
print_to_stderr(f"{category} tests:\n {tests_str}")
print_tests(
"Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)]
)
print_tests(
"Prioritized serial", [x for x in prioritized_tests if must_serial(x)]
)
print_tests(
"General parallel", [x for x in general_tests if not must_serial(x)]
)
print_tests("General serial", [x for x in general_tests if must_serial(x)])
if options.dry_run:
return
if options.dynamo:
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
@@ -1667,17 +1646,17 @@ def main():
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
prioritized_failures: List[TestFailure] = []
general_failures: List[TestFailure] = []
start_time = time.time()
failure_messages = []
# First run the prioritized tests, then the remaining tests.
try:
run_tests(prioritized_tests, test_directory, options, prioritized_failures)
metrics_dict["prioritized_failures"] = [x.test for x in prioritized_failures]
metrics_dict["general_start_time"] = time.time() - start_time
run_tests(general_tests, test_directory, options, general_failures)
metrics_dict["general_end_time"] = time.time() - start_time
metrics_dict["all_failures"] = [x.test for x in general_failures]
failure_messages = run_tests(
prioritized_tests, test_directory, options, "Prioritized tests"
)
failure_messages += run_tests(
remaining_tests, test_directory, options, "General tests"
)
finally:
if options.coverage:
@@ -1692,12 +1671,8 @@ def main():
if not PYTORCH_COLLECT_COVERAGE:
cov.html_report()
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
emit_metric("td_experiment_1", metrics_dict)
all_failures = prioritized_failures + general_failures
if len(all_failures) != 0:
for _, err in all_failures:
if len(failure_messages) != 0:
for err in failure_messages:
print_to_stderr(err)
# A disabled test is expected to fail, so there is no need to report a failure here

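For orientation, the run_test.py changes above go back to plain failure-message strings and a per-group return value, but the overall shape stays the same: files that must run alone are split off, everything else goes through a spawn-context multiprocessing pool, and failures are collected either way. A minimal sketch of that shape, with stand-in helpers (only `must_serial` and the spawn-context pool come from the diff; everything else is illustrative):

```python
# Rough sketch of the parallel/serial split in run_tests(): files that must
# run alone go through a serial loop, the rest go through a spawn-context pool.
# Illustrative only -- run_one and this must_serial stand in for the real
# run_test_module and must_serial from run_test.py.
from multiprocessing import get_context
from typing import List, Optional


def must_serial(test: str) -> bool:
    # Stand-in predicate (the real one also checks env vars, CUDA mem-leak
    # runs, and other special cases).
    return test.startswith("distributed/")


def run_one(test: str) -> Optional[str]:
    # Stand-in for run_test_module(): None on success, an error message on failure.
    return None if "ok" in test else f"{test} failed"


def run_group(tests: List[str], procs: int = 2) -> List[str]:
    failures: List[str] = []
    parallel = [t for t in tests if not must_serial(t)]
    serial = [t for t in tests if must_serial(t)]

    with get_context("spawn").Pool(procs) as pool:
        for message in pool.map(run_one, parallel):
            if message is not None:
                failures.append(message)

    for test in serial:  # serial files run one at a time in this process
        message = run_one(test)
        if message is not None:
            failures.append(message)
    return failures


if __name__ == "__main__":
    print(run_group(["test_ok_basic", "distributed/test_ok_store", "test_autograd"]))
```

In the real script the pool is wired to a completion callback (`parallel_test_completion_callback` above) so that `--keep-going` / `continue_through_error` can stop the run early; `pool.map` here is a simplification.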
tools/stats/import_test_stats.py

@@ -19,8 +19,6 @@ IGNORE_DISABLED_ISSUES: List[str] = get_disabled_issues()
SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
TEST_FILE_RATINGS_FILE = ".pytorch-test-file-ratings.json"
FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
@@ -118,14 +116,3 @@ def get_disabled_tests(
except Exception:
print("Couldn't download test skip set, leaving all tests enabled...")
return {}
def get_test_file_ratings(
dirpath: str, filename: str = TEST_FILE_RATINGS_FILE
) -> Optional[Dict[str, Any]]:
url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/file_test_rating.json"
try:
return fetch_and_cache(dirpath, filename, url, lambda x: x)
except Exception:
print("Couldn't download test file ratings file, not reordering...")
return {}

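The `get_test_file_ratings` helper removed above follows the same fetch-and-cache pattern as `get_slow_tests` and `get_disabled_tests`: download a JSON file into the test directory and reuse it while it is fresh. A rough sketch of such a helper, assuming only the `fetch_and_cache(dirpath, filename, url, processor)` call shape seen in the diff (the body below is an illustration, not the real implementation):

```python
# Illustrative fetch-and-cache helper in the spirit of import_test_stats.py.
# Only the call shape fetch_and_cache(dirpath, filename, url, processor) is
# taken from the diff above; the body is a simplified assumption.
import datetime
import json
import os
import urllib.request
from typing import Any, Callable, Dict

FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds


def fetch_and_cache(
    dirpath: str,
    filename: str,
    url: str,
    process_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
) -> Dict[str, Any]:
    path = os.path.join(dirpath, filename)
    # Reuse the cached copy if it is younger than the lifespan.
    if os.path.exists(path):
        age_sec = datetime.datetime.now().timestamp() - os.path.getmtime(path)
        if age_sec < FILE_CACHE_LIFESPAN_SECONDS:
            with open(path) as f:
                return json.load(f)
    with urllib.request.urlopen(url) as conn:
        data = process_fn(json.loads(conn.read().decode()))
    with open(path, "w") as f:
        json.dump(data, f)
    return data
```

`get_test_file_ratings` above is essentially this call pointed at the file_test_rating.json URL with an identity processor, falling back to `{}` if the download fails.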
tools/stats/upload_stats_lib.py

@@ -249,7 +249,7 @@ class EnvVarMetric:
value = os.environ.get(self.env_var)
if value is None and self.required:
raise ValueError(
f"Missing {self.name}. Please set the {self.env_var} "
f"Missing {self.name}. Please set the {self.env_var}"
"environment variable to pass in this value."
)
if self.type_conversion_fn:

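The single-character difference in the EnvVarMetric message above matters because Python concatenates adjacent string literals with no separator, so without the trailing space the rendered message runs the variable name straight into the word "environment". A quick illustration (the variable values are just examples):

```python
# Adjacent string literals are joined with no separator, so the trailing
# space decides whether the message reads "...the MY_ENV_VAR environment
# variable..." or "...the MY_ENV_VARenvironment variable...".
name = "example metric"   # illustrative values, not the real EnvVarMetric fields
env_var = "MY_ENV_VAR"

without_space = (
    f"Missing {name}. Please set the {env_var}"
    "environment variable to pass in this value."
)
with_space = (
    f"Missing {name}. Please set the {env_var} "
    "environment variable to pass in this value."
)

print(without_space)  # ...Please set the MY_ENV_VARenvironment variable...
print(with_space)     # ...Please set the MY_ENV_VAR environment variable...
```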
tools/test/test_test_selections.py

@@ -394,24 +394,28 @@ class TestParsePrevTests(unittest.TestCase):
"tools.testing.test_selections._get_modified_tests",
return_value={"test2", "test4"},
)
@mock.patch(
"tools.testing.test_selections._get_file_rating_tests", return_value=["test1"]
)
def test_get_reordered_tests(
self,
mock_get_prev_failing_tests: Any,
mock_get_modified_tests: Any,
mock_get_file_rating_tests: Any,
self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any
) -> None:
tests = ["test1", "test2", "test3", "test4", "test5"]
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
]
expected_prioritized_tests = ["test4", "test2", "test1"]
expected_remaining_tests = {"test3", "test5"}
expected_prioritized_tests = {"test4", "test2"}
expected_remaining_tests = {"test1", "test3", "test5"}
prioritized_tests, remaining_tests = get_reordered_tests(tests)
self.assertListEqual(expected_prioritized_tests, prioritized_tests)
self.assertSetEqual(expected_remaining_tests, set(remaining_tests))
# Just want to check the names of the tests
prioritized_tests_name = {test.name for test in prioritized_tests}
remaining_tests_name = {test.name for test in remaining_tests}
self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
tests = [

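The restored version of this test builds ShardedTest objects directly and compares only their names. Based on the fields used here (name, shard, num_shards, time) and the `get_time()` call in `calculate_shards`, the type is essentially a named tuple of file name, pytest-shard index, shard count, and optional duration. A rough stand-in for local experimentation (the real class is defined in tools/testing/test_selections.py):

```python
# Approximate stand-in for ShardedTest, based on the fields used in the test
# above and the get_time() accessor used by calculate_shards; the real class
# lives in tools/testing/test_selections.py.
from typing import NamedTuple, Optional


class ShardedTest(NamedTuple):
    name: str               # test file, e.g. "test_autograd"
    shard: int              # 1-indexed pytest-shard number within the file
    num_shards: int         # how many pieces the file was split into
    time: Optional[float]   # known duration in seconds, if any

    def get_time(self) -> float:
        return self.time or 0.0


tests = [
    ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
    ShardedTest(name="test2", shard=2, num_shards=2, time=None),
]
# calculate_shards sorts by known duration, longest first:
print(sorted(tests, key=lambda t: t.get_time(), reverse=True))
```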
tools/testing/test_selections.py

@@ -3,24 +3,16 @@ import json
import math
import os
import subprocess
from collections import defaultdict
from pathlib import Path
from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from warnings import warn
from tools.shared.logging_utils import duration_to_str, pluralize
from tools.stats.import_test_stats import (
get_disabled_tests,
get_slow_tests,
get_test_file_ratings,
TEST_FILE_RATINGS_FILE,
)
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
from tools.stats.upload_stats_lib import emit_metric
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
@@ -89,8 +81,8 @@ def get_with_pytest_shard(
) -> List[ShardedTest]:
sharded_tests: List[ShardedTest] = []
for test in tests:
duration = test_file_times.get(test, None)
if duration and duration > THRESHOLD:
duration = test_file_times[test]
if duration > THRESHOLD:
num_shards = math.ceil(duration / THRESHOLD)
for i in range(num_shards):
sharded_tests.append(
@@ -106,24 +98,20 @@ def calculate_shards(
tests: List[str],
test_file_times: Dict[str, float],
must_serial: Optional[Callable[[str], bool]] = None,
sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
must_serial = must_serial or (lambda x: True)
known_tests = tests
unknown_tests = []
known_tests = [x for x in tests if x in test_file_times]
unknown_tests: List[str] = [x for x in tests if x not in known_tests]
if sort_by_time:
known_tests = [x for x in tests if x in test_file_times]
unknown_tests = [x for x in tests if x not in known_tests]
known_tests = get_with_pytest_shard(known_tests, test_file_times)
if sort_by_time:
known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
sorted_tests = sorted(
get_with_pytest_shard(known_tests, test_file_times),
key=lambda j: j.get_time(),
reverse=True,
)
sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
for test in known_tests:
for test in sorted_tests:
if must_serial(test.name):
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
min_sharded_job.serial.append(test)
@@ -139,7 +127,7 @@ def calculate_shards(
return [job.convert_to_tuple() for job in sharded_jobs]
def _query_changed_files() -> List[str]:
def _query_changed_test_files() -> List[str]:
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
merge_base = (
subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
@@ -198,7 +186,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st
def _get_modified_tests() -> Set[str]:
try:
changed_files = _query_changed_files()
changed_files = _query_changed_test_files()
except Exception as e:
warn(f"Can't query changed test files due to {e}")
# If unable to get changed files from git, quit without doing any sorting
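For context, `_query_changed_test_files` (the name this revert restores) diffs the merge base against HEAD, and `_get_modified_tests` keeps only the entries that look like test modules. A condensed sketch of that flow; the merge-base call mirrors the diff, while the `git diff --name-only` step and the `test/` filter are simplifying assumptions:

```python
# Condensed sketch of _query_changed_test_files / _get_modified_tests: find
# the merge base with the default branch, list the files changed since then,
# and keep the ones that look like test modules.
import os
import subprocess
from typing import List, Set


def query_changed_test_files() -> List[str]:
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
    merge_base = (
        subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
        .decode()
        .strip()
    )
    out = subprocess.check_output(["git", "diff", "--name-only", merge_base, "HEAD"])
    return out.decode().splitlines()


def get_modified_tests() -> Set[str]:
    modified: Set[str] = set()
    for path in query_changed_test_files():
        # e.g. "test/test_autograd.py" -> "test_autograd"
        if path.startswith("test/") and path.endswith(".py"):
            modified.add(path[len("test/"):-len(".py")])
    return modified
```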
@@ -283,86 +271,78 @@ def log_time_savings(
return max_time_savings_sec
def _get_file_rating_tests() -> List[str]:
path = REPO_ROOT / "test" / TEST_FILE_RATINGS_FILE
if not os.path.exists(path):
print(f"could not find path {path}")
return []
with open(path) as f:
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
try:
changed_files = _query_changed_files()
except Exception as e:
warn(f"Can't query changed test files due to {e}")
return []
ratings: Dict[str, float] = defaultdict(float)
for file in changed_files:
for test_file, score in test_file_ratings.get(file, {}).items():
ratings[test_file] += score
prioritize = sorted(ratings, key=lambda x: ratings[x])
return prioritize
def get_reordered_tests(
tests: List[str],
) -> Tuple[List[str], List[str]]:
tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
"""
Get the reordered test filename list based on GitHub PR history or git changed files.
We prioritize running test files that were changed.
"""
prioritized_tests: List[str] = []
def add_tests(
tests_to_add: Union[List[str], Set[str]], test_group_description: str
) -> None:
if not tests_to_add:
def print_tests(tests: Set[str], test_group_description: str) -> None:
if not tests:
return
print(f"{test_group_description}:")
for test in tests_to_add:
if test in tests:
print(f" {test}")
if test not in prioritized_tests:
prioritized_tests.append(test)
for test in tests:
print(f" {test}")
add_tests(
_get_previously_failing_tests(),
"If run, these tests will prioritized because they previously failed",
prioritized_tests: Set[str] = set()
pri_test = _get_previously_failing_tests()
print_tests(
pri_test, "If run, these tests will prioritized because they previously failed"
)
prioritized_tests |= pri_test
add_tests(
_get_modified_tests(),
"If run, these tests will be prioritized because they were modified",
pri_test |= _get_modified_tests()
print_tests(
pri_test, "If run, these tests will be prioritized because they were modified"
)
prioritized_tests |= pri_test
add_tests(
_get_file_rating_tests(),
"If run, these tests will be preioritized for an experiment in TD",
)
bring_to_front = []
the_rest = []
prioritized_tests = [x for x in prioritized_tests if x in tests]
the_rest = [x for x in tests if x not in prioritized_tests]
for test in tests:
if test.name in prioritized_tests:
bring_to_front.append(test)
else:
the_rest.append(test)
if prioritized_tests:
test_cnt_str = pluralize(len(tests), "test")
if len(tests) != len(bring_to_front) + len(the_rest):
print(
f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
)
return ([], tests)
prioritized_test_names = []
remaining_test_names = []
if bring_to_front:
test_cnt_str = pluralize(len(tests), "test")
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
prioritized_test_names = [t.name for t in bring_to_front]
print(f"Prioritized: {prioritized_test_names}")
remaining_test_names = [t.name for t in the_rest]
print(f"The Rest: {remaining_test_names}")
else:
print("Didn't find any tests to prioritize")
emit_metric(
"test_reordering_prioritized_tests",
{
"prioritized_test_cnt": len(prioritized_tests),
"prioritized_test_cnt": len(bring_to_front),
"total_test_cnt": len(tests),
"prioritized_tests": prioritized_tests,
"remaining_tests": the_rest,
"prioritized_tests": prioritized_test_names,
"remaining_tests": remaining_test_names,
},
)
return (prioritized_tests, the_rest)
return (bring_to_front, the_rest)
def get_test_case_configs(dirpath: str) -> None:
get_slow_tests(dirpath=dirpath)
get_disabled_tests(dirpath=dirpath)
get_test_file_ratings(dirpath=dirpath)
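
Putting the restored pieces together, `calculate_shards` splits long files into pytest shards, sorts files with known timings by duration, and greedily hands each one to the shard with the smallest running total, while files without timing data are spread round-robin. A compact sketch of that algorithm, with the pytest-shard splitting and the serial/parallel bookkeeping of `ShardJob` omitted:

```python
# Compact sketch of the greedy sharding restored above: sort files with known
# timings by duration and always assign to the shard with the smallest running
# total; files without timing data are spread round-robin at the end.
# Simplified -- pytest-shard splitting and must_serial handling are omitted.
from typing import Dict, List, Tuple


def calculate_shards_sketch(
    num_shards: int,
    tests: List[str],
    test_file_times: Dict[str, float],
) -> List[Tuple[float, List[str]]]:
    shards: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]

    known = sorted(
        (t for t in tests if t in test_file_times),
        key=lambda t: test_file_times[t],
        reverse=True,
    )
    for test in known:
        idx = min(range(num_shards), key=lambda i: shards[i][0])
        total, names = shards[idx]
        shards[idx] = (total + test_file_times[test], names + [test])

    unknown = [t for t in tests if t not in test_file_times]
    for i, test in enumerate(unknown):
        total, names = shards[i % num_shards]
        shards[i % num_shards] = (total, names + [test])

    return shards


# Example: two shards over four files, one of them with no recorded time.
print(calculate_shards_sketch(2, ["a", "b", "c", "d"], {"a": 100.0, "b": 60.0, "c": 50.0}))
```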