Improve time savings calculation math for test reordering (#102411)

Use a more accurate method that accounts for tests being run in parallel

Right now we only log the result to the console, but later it'll be logged to Rockset for better tracking
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102411
Approved by: https://github.com/huydhn, https://github.com/malfet
Zain Rizvi
2023-05-31 23:51:27 +00:00
committed by PyTorch MergeBot
parent 693114c0a2
commit c84f246c83
3 changed files with 167 additions and 10 deletions
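
As a rough illustration of why the parallel-aware math matters (not part of this diff; hypothetical test names, reusing ShardedTest and log_time_savings from the changes below):

tests = [
    ShardedTest(name="slow", shard=1, num_shards=2, time=4.0),
    ShardedTest(name="medium", shard=1, num_shards=2, time=3.0),
    ShardedTest(name="flaky", shard=1, num_shards=2, time=2.0),
]
prioritized = [t for t in tests if t.name == "flaky"]
# Old math summed the regular-test time seen so far as if everything ran in a
# single stream, so it would report 4 + 3 = 7s of savings here.
# New math simulates 2 procs: "flaky" would have started once the first proc
# freed up at t=3, so the estimated savings is roughly 3s.
savings = log_time_savings(
    tests, prioritized, is_serial_test_fn=lambda name: False, num_procs=2
)  # returns 3.0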

View File

@@ -44,6 +44,7 @@ try:
calculate_shards,
get_reordered_tests,
get_test_case_configs,
log_time_savings,
NUM_PROCS,
ShardedTest,
THRESHOLD,
@@ -1610,6 +1611,13 @@ def main():
remaining_tests = selected_tests
if IS_CI:
(prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
log_time_savings(
selected_tests,
prioritized_tests,
is_serial_test_fn=must_serial,
num_procs=NUM_PROCS,
)
# downloading test cases configuration to local environment
get_test_case_configs(dirpath=test_directory)

View File

@@ -16,6 +16,7 @@ try:
_get_previously_failing_tests,
calculate_shards,
get_reordered_tests,
log_time_savings,
ShardedTest,
THRESHOLD,
)
@@ -344,6 +345,10 @@ def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
return file_object
def never_serial(test_name: str) -> bool:
return False
class TestParsePrevTests(unittest.TestCase):
@mock.patch("pathlib.Path.exists", return_value=False)
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
@@ -412,6 +417,98 @@ class TestParsePrevTests(unittest.TestCase):
self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=7.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=5.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=4.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=2.0),
ShardedTest(name="test6", shard=1, num_shards=2, time=1.0),
]
prioritized_tests = [
test for test in tests if test.name in ["test4", "test5", "test8"]
]
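# Rough math (2 procs): in the original order test1 + test4 fill proc 1 (7 + 3)
# and test2 + test3 fill proc 2 (5 + 4), so prioritized test5 wouldn't start
# until time 9; in the prioritized pool it starts right away, hence savings of 9.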
expected_time_savings = 9.0
time_savings = log_time_savings(
tests, prioritized_tests, is_serial_test_fn=never_serial, num_procs=2
)
self.assertEqual(
time_savings, expected_time_savings, "Received an unexpected time savings"
)
def test_compute_prioritization_time_savings_with_multiple_threads_and_many_prioritized_tests(
self,
) -> None:
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=4.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=2.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=4.0),
ShardedTest(name="test6", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test7", shard=1, num_shards=2, time=5.0),
]
prioritized_tests = [
test for test in tests if test.name in ["test2", "test3", "test7"]
]
# Drawing out the math here since this is a complicated example
# Logic for original execution assuming 2 procs
# Test | Proc 1 | Proc 2
# test1 | 4 |
# test2 | | 3
# test3 | | 2
# test4 | 3 |
# test5 | | 4
# test6 | 3 |
# test7 | | 5 <- starts at time 9 ( 3 + 2 + 4)
# Logic for new execution's prioritized pool:
# Test | Proc 1 | Proc 2
# test3 | 2 |
# test2 | | 3
# test7 | 5 | <- now starts at time 2
# Time savings = 9 - 2 = 7
expected_time_savings = 7.0
time_savings = log_time_savings(
tests, prioritized_tests, is_serial_test_fn=never_serial, num_procs=2
)
self.assertEqual(
time_savings, expected_time_savings, "Received an unexpected time savings"
)
def test_compute_prioritization_time_savings_with_serialized_test(self) -> None:
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=7.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=5.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=4.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=2.0),
ShardedTest(name="test6", shard=1, num_shards=2, time=1.0),
]
prioritized_tests = [test for test in tests if test.name in ["test3", "test6"]]
def serialized(test: str) -> bool:
return test in ["test4", "test6"]
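# Rough math (2 procs): in the original order both procs are busy until time 9
# and the earlier serial test4 (3s) must run first, so serial test6 wouldn't
# start until time 12; in the prioritized pool it only waits for test3 (4s),
# so the savings is 12 - 4 = 8.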
expected_time_savings = 8.0
time_savings = log_time_savings(
tests, prioritized_tests, is_serial_test_fn=serialized, num_procs=2
)
self.assertEqual(
time_savings, expected_time_savings, "Received an unexpected time savings"
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,3 +1,4 @@
import heapq
import json
import math
import os
@@ -195,6 +196,67 @@ def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
return valid_tests
class PoolTimes:
def __init__(self, num_procs: int) -> None:
self.pool_times = [0.0 for _ in range(num_procs)]
self.serial_times = 0.0
def next_test_start_time(self, serial: bool) -> float:
if serial:
# Serial tests are run after all parallel tests complete
return max(self.pool_times) + self.serial_times
return self.pool_times[0]
def schedule_test(self, test: ShardedTest, serial: bool) -> None:
if serial:
self.serial_times += test.get_time()
else:
# pool_times[0] is always the thread with the least amount of time scheduled
heapq.heappushpop(self.pool_times, self.pool_times[0] + test.get_time())
def log_time_savings(
selected_tests: List[ShardedTest],
prioritized_tests: List[ShardedTest],
is_serial_test_fn: Callable[[str], bool],
num_procs: int = NUM_PROCS, # make this customizable for testing
) -> float:
# The tests will be run in [num_procs] parallel threads, so we assume each test
# is allocated to the thread that'll free up first.
# This isn't an exact match (since other factors could change which thread
# pool a test gets scheduled on) but it's a good approximation.
# Simulates the scheduled tests on each thread pool
default_pool = PoolTimes(num_procs) # originally scheduled run
prioritized_pool = PoolTimes(num_procs) # run for prioritized tests
max_time_savings_sec = 0.0
# It's easier to look up prioritized tests by name
prioritized_test_names = {test.name for test in prioritized_tests}
for test in selected_tests:
serial = is_serial_test_fn(test.name)
if test.name in prioritized_test_names:
# Successive tests will always have a greater time savings
max_time_savings_sec = default_pool.next_test_start_time(
serial
) - prioritized_pool.next_test_start_time(serial)
# "schedule" this test on the prioritized pool to get time savings for future prioritized tests
prioritized_pool.schedule_test(test, serial)
# always schedule on the default pool to know what the unprioritized timeline would've looked like
default_pool.schedule_test(test, serial)
print(
f"Prioritized tests will run about {duration_to_str(max_time_savings_sec)} sooner than they would've otherwise"
)
# Return value used by tests
return max_time_savings_sec
def get_reordered_tests(
tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
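
The pool bookkeeping above leans on heapq keeping the smallest per-proc finish time at index 0, so pool_times[0] is always the proc that frees up next. A tiny standalone sketch (not part of this diff) of how the finish times evolve when tests of 7s, 5s and 4s are scheduled on two procs:

import heapq

pool = [0.0, 0.0]  # per-proc finish times; pool[0] is the earliest-free proc
for t in (7.0, 5.0, 4.0):
    # put the next test on whichever proc frees up first
    heapq.heappushpop(pool, pool[0] + t)
print(pool)  # [7.0, 9.0]: one proc ran the 7s test, the other ran 5s then 4s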
@@ -228,18 +290,11 @@ get_reordered_tests
bring_to_front = []
the_rest = []
test_time_for_regular_tests_so_far = 0.0
# how much sooner did we run prioritized tests compared to a naive ordering
time_savings_sec = 0.0
for test in tests:
if test.name in prioritized_tests:
bring_to_front.append(test)
# Calculate approx time saved by reordering
time_savings_sec = test_time_for_regular_tests_so_far
else:
the_rest.append(test)
test_time_for_regular_tests_so_far += test.get_time()
if len(tests) != len(bring_to_front) + len(the_rest):
print(
@@ -252,9 +307,6 @@ def get_reordered_tests(
if bring_to_front:
test_cnt_str = pluralize(len(tests), "test")
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
print(
f"Prioritized tests estimated to run up to {duration_to_str(time_savings_sec)} sooner than they would've otherwise"
)
prioritized_test_names = [t.name for t in bring_to_front]
print(f"Prioritized: {prioritized_test_names}")