mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Improve time savings calculation math for test reordering (#102411)
Use a more accurate method that accounts for tests being run in parallel Right now we still log results to the console, but later it'll get logged to Rockset for better tracking Pull Request resolved: https://github.com/pytorch/pytorch/pull/102411 Approved by: https://github.com/huydhn, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
693114c0a2
commit
c84f246c83
@ -1,3 +1,4 @@
|
||||
import heapq
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@ -195,6 +196,67 @@ def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
|
||||
return valid_tests
|
||||
|
||||
|
||||
class PoolTimes:
|
||||
def __init__(self, num_procs: int) -> None:
|
||||
self.pool_times = [0.0 for _ in range(num_procs)]
|
||||
self.serial_times = 0.0
|
||||
|
||||
def next_test_start_time(self, serial: bool) -> float:
|
||||
if serial:
|
||||
# Serial tests are run after all parallel tests complete
|
||||
return max(self.pool_times) + self.serial_times
|
||||
|
||||
return self.pool_times[0]
|
||||
|
||||
def schedule_test(self, test: ShardedTest, serial: bool) -> None:
|
||||
if serial:
|
||||
self.serial_times += test.get_time()
|
||||
else:
|
||||
# pool_times[0] is always the thread with the least amount of time scheduled
|
||||
heapq.heappushpop(self.pool_times, self.pool_times[0] + test.get_time())
|
||||
|
||||
|
||||
def log_time_savings(
|
||||
selected_tests: List[ShardedTest],
|
||||
prioritized_tests: List[ShardedTest],
|
||||
is_serial_test_fn: Callable[[str], bool],
|
||||
num_procs: int = NUM_PROCS, # make this customizable for testing
|
||||
) -> float:
|
||||
# The tests will be run in [num_procs] parallel threads, so we assume each test
|
||||
# is allocated to the thread that'll free up first.
|
||||
# This isn't an exact match (since other factors could change which thread
|
||||
# pool a test gets scheduled on) but it's a good approximation.
|
||||
|
||||
# Simulates the scheduled tests on each thread pool
|
||||
default_pool = PoolTimes(num_procs) # originally scheduled run
|
||||
prioritized_pool = PoolTimes(num_procs) # run for prioritized tests
|
||||
max_time_savings_sec = 0.0
|
||||
|
||||
# It's easier to look up prioritized tests by name
|
||||
prioritized_test_names = {test.name for test in prioritized_tests}
|
||||
|
||||
for test in selected_tests:
|
||||
serial = is_serial_test_fn(test.name)
|
||||
if test.name in prioritized_test_names:
|
||||
# Successive tests will always have a greater time savings
|
||||
max_time_savings_sec = default_pool.next_test_start_time(
|
||||
serial
|
||||
) - prioritized_pool.next_test_start_time(serial)
|
||||
|
||||
# "schedule" this test on the prioritized pool to get time savings for future prioritized tests
|
||||
prioritized_pool.schedule_test(test, serial)
|
||||
|
||||
# always schedule on the default pool to know what the unprioritized timeline would've looked like
|
||||
default_pool.schedule_test(test, serial)
|
||||
|
||||
print(
|
||||
f"Prioritized tests will run about {duration_to_str(max_time_savings_sec)} sooner than they would've otherwise"
|
||||
)
|
||||
|
||||
# Return value used by tests
|
||||
return max_time_savings_sec
|
||||
|
||||
|
||||
def get_reordered_tests(
|
||||
tests: List[ShardedTest],
|
||||
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
|
||||
@ -228,18 +290,11 @@ def get_reordered_tests(
|
||||
bring_to_front = []
|
||||
the_rest = []
|
||||
|
||||
test_time_for_regular_tests_so_far = 0.0
|
||||
# how much sooner did we run prioritized tests compared to a naive ordering
|
||||
time_savings_sec = 0.0
|
||||
|
||||
for test in tests:
|
||||
if test.name in prioritized_tests:
|
||||
bring_to_front.append(test)
|
||||
# Calculate approx time saved by reordering
|
||||
time_savings_sec = test_time_for_regular_tests_so_far
|
||||
else:
|
||||
the_rest.append(test)
|
||||
test_time_for_regular_tests_so_far += test.get_time()
|
||||
|
||||
if len(tests) != len(bring_to_front) + len(the_rest):
|
||||
print(
|
||||
@ -252,9 +307,6 @@ def get_reordered_tests(
|
||||
if bring_to_front:
|
||||
test_cnt_str = pluralize(len(tests), "test")
|
||||
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
|
||||
print(
|
||||
f"Prioritized tests estimated to run up to {duration_to_str(time_savings_sec)} sooner than they would've otherwise"
|
||||
)
|
||||
|
||||
prioritized_test_names = [t.name for t in bring_to_front]
|
||||
print(f"Prioritized: {prioritized_test_names}")
|
||||
|
Reference in New Issue
Block a user