Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Fix round robin sharding (#121022)
Fix round robin sharding when there are no test times and sort_by_time=False.
Adds more tests to test_test_selections for sort_by_time=False.
Adds more checks to test_split_shards_random for serial/parallel ordering and for the ordering of tests within each shard.
Refactors duplicated code.
Tested locally by running `python test/run_test.py --shard 3 5` with no test times downloaded and checking that the resulting shard was not an empty list.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121022
Approved by: https://github.com/huydhn, https://github.com/osalpekar
Committed by: PyTorch MergeBot
Parent: 9d83f9dc0e
Commit: effdea5fc6
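The behavior this commit restores is easiest to see in miniature. Below is a minimal, self-contained sketch of round-robin sharding when no timing data is available; it is an illustration only, not the PyTorch implementation, and the helper name round_robin_shards is made up for this example.

from typing import List

def round_robin_shards(tests: List[str], num_shards: int) -> List[List[str]]:
    # Deal tests out one at a time, cycling through the shards, so every
    # shard gets roughly the same number of tests even with no timing data.
    shards: List[List[str]] = [[] for _ in range(num_shards)]
    for i, test in enumerate(tests):
        shards[i % num_shards].append(test)
    return shards

# Example: 5 tests over 2 shards -> ['t0', 't2', 't4'] and ['t1', 't3']
print(round_robin_shards([f"t{i}" for i in range(5)], 2))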
@@ -1,3 +1,4 @@
+import functools
 import pathlib
 import random
 import sys
@@ -72,6 +73,109 @@ class TestCalculateShards(unittest.TestCase):
             self.assertAlmostEqual(expected[0], actual[0])
             self.assertListEqual(expected[1], actual[1])
 
+    def test_no_times(self) -> None:
+        # Check that round robin sharding is used when no times are provided
+        expected_shards = [
+            (
+                0.0,
+                [
+                    ShardedTest(
+                        test="super_long_test", shard=1, num_shards=1, time=None
+                    ),
+                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
+                ],
+            ),
+            (
+                0.0,
+                [
+                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
+                ],
+            ),
+        ]
+        self.assert_shards_equal(
+            expected_shards,
+            calculate_shards(2, self.tests, {}, {}, sort_by_time=False),
+        )
+
+    def test_some_times_with_not_sort_by_time(self) -> None:
+        expected_shards = [
+            (
+                300.0,
+                [
+                    ShardedTest(test="test_1", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="test_3", shard=1, num_shards=1, time=300),
+                    ShardedTest(test="test_4", shard=1, num_shards=1, time=None),
+                ],
+            ),
+            (
+                400.0,
+                [
+                    ShardedTest(test="test_2", shard=1, num_shards=1, time=400),
+                    ShardedTest(test="test_5", shard=1, num_shards=1, time=None),
+                ],
+            ),
+        ]
+        self.assert_shards_equal(
+            expected_shards,
+            calculate_shards(
+                2,
+                [
+                    TestRun("test_1"),
+                    TestRun("test_2"),
+                    TestRun("test_3"),
+                    TestRun("test_4"),
+                    TestRun("test_5"),
+                ],
+                {"test_2": 400, "test_3": 300},
+                {},
+                sort_by_time=False,
+            ),
+        )
+
+    def test_serial_parallel_interleaving(self) -> None:
+        expected_shards = [
+            (
+                400.0,
+                [
+                    ShardedTest(test="test_1", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="test_4", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="test_5", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="test_2", shard=1, num_shards=1, time=400),
+                ],
+            ),
+            (
+                300.0,
+                [
+                    ShardedTest(test="test_3", shard=1, num_shards=1, time=300),
+                ],
+            ),
+        ]
+        self.assert_shards_equal(
+            expected_shards,
+            calculate_shards(
+                2,
+                [
+                    TestRun("test_1"),
+                    TestRun("test_2"),
+                    TestRun("test_3"),
+                    TestRun("test_4"),
+                    TestRun("test_5"),
+                ],
+                {"test_2": 400, "test_3": 300},
+                {},
+                must_serial=lambda x: x in ["test_1", "test_3"],
+                sort_by_time=False,
+            ),
+        )
+
     def test_calculate_2_shards_with_complete_test_times(self) -> None:
         expected_shards = [
             (
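The expected shards in test_no_times above follow from dealing the fixture's tests out alternately across the two shards. The test order below is inferred from the expected output (an assumption, since the self.tests fixture itself is not shown in this hunk):

# Assumed fixture order, reconstructed from the expected shards above.
tests = [
    "super_long_test", "long_test1", "long_test2",
    "normal_test1", "normal_test2", "normal_test3",
    "short_test1", "short_test2", "short_test3",
    "short_test4", "short_test5",
]
# With every time unknown, all tests weigh the same, so greedy
# "append to the least-loaded shard" degenerates to round robin.
shard0 = tests[0::2]  # matches the first expected shard
shard1 = tests[1::2]  # matches the second expected shard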
@@ -213,8 +317,7 @@ class TestCalculateShards(unittest.TestCase):
                 22.0,
                 [
                     ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
-                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
-                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
                 ],
             ),
             (
@@ -228,7 +331,8 @@ class TestCalculateShards(unittest.TestCase):
                 1.0,
                 [
                     ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
-                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
                 ],
             ),
             (
@@ -330,23 +434,27 @@ class TestCalculateShards(unittest.TestCase):
         for _ in range(100):
             num_shards = random.randint(1, 10)
             num_tests = random.randint(1, 100)
+            test_names = [str(i) for i in range(num_tests)]
+            tests = [TestRun(x) for x in test_names]
+            serial = [x for x in test_names if random.randint(0, 1) == 0]
+            has_times = [x for x in test_names if random.randint(0, 1) == 0]
             random_times: Dict[str, float] = {
-                str(i): random.randint(0, THRESHOLD * 10) for i in range(num_tests)
+                i: random.randint(0, THRESHOLD * 10) for i in has_times
             }
-            serial = [str(i) for i in range(num_tests) if random.randint(0, 1) == 0]
+            sort_by_time = random.randint(0, 1) == 0
 
             shards = calculate_shards(
                 num_shards,
-                [TestRun(t) for t in random_times.keys()],
+                tests,
                 random_times,
                 None,
                 must_serial=lambda x: x in serial,
-                sort_by_time=random.randint(0, 1) == 0,
+                sort_by_time=sort_by_time,
             )
 
             times = [x[0] for x in shards]
             max_diff = max(times) - min(times)
-            self.assertTrue(max_diff <= THRESHOLD)
+            self.assertTrue(max_diff <= THRESHOLD + (num_tests - len(has_times)) * 60)
 
             all_sharded_tests: Dict[str, List[ShardedTest]] = defaultdict(list)
             for _, sharded_tests in shards:
@@ -354,19 +462,58 @@ class TestCalculateShards(unittest.TestCase):
                     all_sharded_tests[sharded_test.name].append(sharded_test)
 
             # Check that all test files are represented in the shards
-            self.assertListEqual(
-                sorted(random_times.keys()), sorted(all_sharded_tests.keys())
-            )
+            self.assertListEqual(sorted(test_names), sorted(all_sharded_tests.keys()))
             # Check that for each test file, the pytest shards' times adds up to
             # original and all shards are present
             for test, sharded_tests in all_sharded_tests.items():
-                self.assertAlmostEqual(
-                    random_times[test], sum(x.time or 0 for x in sharded_tests)
-                )
-                self.assertListEqual(
-                    list(range(sharded_tests[0].num_shards)),
-                    sorted(x.shard - 1 for x in sharded_tests),
-                )
+                if random_times.get(test) is None:
+                    self.assertTrue(len(sharded_tests) == 1)
+                    self.assertTrue(sharded_tests[0].time is None)
+                else:
+                    # x.time is not None because of the above check
+                    self.assertAlmostEqual(
+                        random_times[test], sum(x.time for x in sharded_tests)  # type: ignore[misc]
+                    )
+                    self.assertListEqual(
+                        list(range(sharded_tests[0].num_shards)),
+                        sorted(x.shard - 1 for x in sharded_tests),
+                    )
+            # Tests without times are serial
+            true_serial = set(serial + [x for x in test_names if x not in has_times])
+            # Check that sort_by_time is respected
+            if sort_by_time:
+
+                def comparator(a: ShardedTest, b: ShardedTest) -> int:
+                    # serial comes first
+                    if a.name in true_serial and b.name not in true_serial:
+                        return -1
+                    if a.name not in true_serial and b.name in true_serial:
+                        return 1
+                    # known test times come first
+                    if a.time is not None and b.time is None:
+                        return -1
+                    if a.time is None and b.time is not None:
+                        return 1
+                    if a.time == b.time:
+                        return 0
+                    # not None due to the above checks
+                    return -1 if a.time > b.time else 1  # type: ignore[operator]
+
+            else:
+
+                def comparator(a: ShardedTest, b: ShardedTest) -> int:
+                    # serial comes first
+                    if a.name in true_serial and b.name not in true_serial:
+                        return -1
+                    if a.name not in true_serial and b.name in true_serial:
+                        return 1
+                    return test_names.index(a.name) - test_names.index(b.name)
+
+            for _, sharded_tests in shards:
+                self.assertListEqual(
+                    sorted(sharded_tests, key=functools.cmp_to_key(comparator)),
+                    sharded_tests,
+                )
 
     def test_calculate_2_shards_against_optimal_shards(self) -> None:
         random.seed(120)
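The ordering check added above relies on functools.cmp_to_key, which turns a three-way comparator (negative, zero, positive return values) into a sort key. A tiny standalone illustration of the same pattern, unrelated to the sharding code:

import functools

def compare(a: str, b: str) -> int:
    # Shorter strings first; ties broken alphabetically.
    if len(a) != len(b):
        return len(a) - len(b)
    return -1 if a < b else (1 if a > b else 0)

words = ["pear", "fig", "apple", "kiwi"]
print(sorted(words, key=functools.cmp_to_key(compare)))
# ['fig', 'kiwi', 'pear', 'apple']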
@@ -293,8 +293,8 @@ class ShardedTest:
     def __str__(self) -> str:
         return f"{self.test} {self.shard}/{self.num_shards}"
 
-    def get_time(self) -> float:
-        return self.time or 0
+    def get_time(self, default: float = 0) -> float:
+        return self.time if self.time is not None else default
 
     def get_pytest_args(self) -> List[str]:
         filter = self.test.get_pytest_filter()
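For context on the hunk above: get_time now substitutes a caller-chosen default only when no time was recorded, instead of always falling back to zero. A minimal sketch of that pattern with a standalone dataclass (not the repo's ShardedTest):

from dataclasses import dataclass
from typing import Optional

@dataclass
class TimedItem:
    name: str
    time: Optional[float] = None

    def get_time(self, default: float = 0) -> float:
        # Substitute `default` only when no time was recorded.
        return self.time if self.time is not None else default

print(TimedItem("known", 12.5).get_time(default=60))  # 12.5
print(TimedItem("unknown").get_time(default=60))      # 60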
@@ -20,6 +20,7 @@ IS_ROCM = os.path.exists("/opt/rocm")
 NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2
 NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
 THRESHOLD = 60 * 10  # 10 minutes
+DEFAULT_TIME = 60  # if no test times available for the test, assume it takes 60s
 
 # See Note [ROCm parallel CI testing]
 # Special logic for ROCm GHA runners to query number of GPUs available.
@@ -48,12 +49,13 @@ class ShardJob:
         self.serial: List[ShardedTest] = []
         self.parallel: List[ShardedTest] = []
 
-    def get_total_time(self) -> float:
+    def get_total_time(self, default: float = 0.0) -> float:
+        """Default is the value for which to substitute if a test has no time"""
         procs = [0.0 for _ in range(NUM_PROCS_FOR_SHARDING_CALC)]
         for test in self.parallel:
             min_index = procs.index(min(procs))
-            procs[min_index] += test.get_time()
-        time = max(procs) + sum(test.get_time() for test in self.serial)
+            procs[min_index] += test.get_time(default)
+        time = max(procs) + sum(test.get_time(default) for test in self.serial)
         return time
 
     def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
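The estimate in get_total_time spreads the parallel tests over NUM_PROCS_FOR_SHARDING_CALC worker slots greedily and then adds the serial tests on top. A rough standalone sketch of the same calculation, using plain floats instead of ShardedTest objects and assuming two worker slots:

from typing import List

def estimate_total_time(parallel: List[float], serial: List[float], num_procs: int = 2) -> float:
    # Greedily assign each parallel test to the least-loaded worker slot,
    # then add the serial tests, which run one at a time.
    procs = [0.0] * num_procs
    for t in parallel:
        procs[procs.index(min(procs))] += t
    return max(procs) + sum(serial)

# Example: 30s goes to slot 0, then 20s and 10s go to slot 1,
# so max(30, 30) = 30, plus 15s of serial work = 45.
print(estimate_total_time([30, 20, 10], [15]))  # 45.0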
@@ -86,6 +88,8 @@ def get_duration(
     test_file_times: Dict[str, float],
     test_class_times: Dict[str, Dict[str, float]],
 ) -> Optional[float]:
+    """Calculate the time for a TestRun based on the given test_file_times and
+    test_class_times. Returns None if the time is unknown."""
     file_duration = test_file_times.get(test.test_file, None)
     if test.is_full_file():
         return file_duration
@@ -123,65 +127,46 @@ def get_duration(
 
 def shard(
     sharded_jobs: List[ShardJob],
-    tests: Sequence[TestRun],
-    test_file_times: Dict[str, float],
-    test_class_times: Dict[str, Dict[str, float]],
+    pytest_sharded_tests: Sequence[ShardedTest],
     estimated_time_limit: Optional[float] = None,
-    sort_by_time: bool = True,
     serial: bool = False,
 ) -> None:
-    if len(sharded_jobs) == 0:
-        assert len(tests) == 0, "No shards provided but there are tests to shard"
-        return
     # Modifies sharded_jobs in place
-    known_tests = tests
-    unknown_tests = []
-    if sort_by_time:
-        known_tests = [
-            x
-            for x in tests
-            if get_duration(x, test_file_times, test_class_times) is not None
-        ]
-        unknown_tests = [x for x in tests if x not in known_tests]
-
-    assert (
-        unknown_tests == [] or serial
-    ), f"Attmempting to parallelize unknown tests {unknown_tests}"
-    del tests
+    if len(sharded_jobs) == 0:
+        assert (
+            len(pytest_sharded_tests) == 0
+        ), "No shards provided but there are tests to shard"
+        return
 
-    known_tests = get_with_pytest_shard(known_tests, test_file_times, test_class_times)
+    def _get_min_sharded_job(sharded_jobs: List[ShardJob]) -> ShardJob:
+        return min(sharded_jobs, key=lambda j: j.get_total_time(default=DEFAULT_TIME))
 
-    if sort_by_time:
-        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
-
-    def _shard_serial(tests: List[ShardedTest], sharded_jobs: List[ShardJob]) -> None:
+    def _shard_serial(
+        tests: Sequence[ShardedTest], sharded_jobs: List[ShardJob]
+    ) -> None:
         assert estimated_time_limit is not None, "Estimated time limit must be provided"
         new_sharded_jobs = sharded_jobs
         for test in tests:
             if (
                 len(sharded_jobs) > 1
-                and sharded_jobs[-1].get_total_time() > estimated_time_limit
+                and sharded_jobs[-1].get_total_time(default=DEFAULT_TIME)
+                > estimated_time_limit
             ):
                 new_sharded_jobs = sharded_jobs[:-1]
-            min_sharded_job = min(new_sharded_jobs, key=lambda j: j.get_total_time())
+            min_sharded_job = _get_min_sharded_job(new_sharded_jobs)
            min_sharded_job.serial.append(test)
 
-    def _shard_parallel(tests: List[ShardedTest], sharded_jobs: List[ShardJob]) -> None:
+    def _shard_parallel(
+        tests: Sequence[ShardedTest], sharded_jobs: List[ShardJob]
+    ) -> None:
         for test in tests:
-            min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
+            min_sharded_job = _get_min_sharded_job(sharded_jobs)
             min_sharded_job.parallel.append(test)
 
     if serial:
-        _shard_serial(known_tests, sharded_jobs)
+        _shard_serial(pytest_sharded_tests, sharded_jobs)
     else:
-        _shard_parallel(known_tests, sharded_jobs)
-
-    # Round robin the unknown jobs starting with the smallest shard
-    num_shards = len(sharded_jobs)
-    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
-    for unknown_test in unknown_tests:
-        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
-        index = (index + 1) % num_shards
+        _shard_parallel(pytest_sharded_tests, sharded_jobs)
 
     return
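The deleted round-robin loop at the end of shard() is no longer needed because tests without times now contribute DEFAULT_TIME to a job's total: when every test carries the same substituted weight, repeatedly picking the least-loaded job already cycles through the jobs in order. A toy illustration of that equivalence, using a bare list of totals as a stand-in for ShardJob:

totals = [0.0, 0.0, 0.0]             # three empty shards
assignment = []
for test in ["a", "b", "c", "d", "e"]:
    idx = totals.index(min(totals))  # least-loaded shard; ties resolve by position
    assignment.append(idx)
    totals[idx] += 60.0              # every unknown test gets the same default weight
print(assignment)  # [0, 1, 2, 0, 1] -- effectively round robin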
@@ -196,22 +181,36 @@ def calculate_shards(
 ) -> List[Tuple[float, List[ShardedTest]]]:
     must_serial = must_serial or (lambda x: True)
     test_class_times = test_class_times or {}
 
+    # Divide tests into pytest shards
+    if sort_by_time:
+        known_tests = [
+            x
+            for x in tests
+            if get_duration(x, test_file_times, test_class_times) is not None
+        ]
+        unknown_tests = [x for x in tests if x not in known_tests]
+
+        pytest_sharded_tests = sorted(
+            get_with_pytest_shard(known_tests, test_file_times, test_class_times),
+            key=lambda j: j.get_time(),
+            reverse=True,
+        ) + get_with_pytest_shard(unknown_tests, test_file_times, test_class_times)
+    else:
+        pytest_sharded_tests = get_with_pytest_shard(
+            tests, test_file_times, test_class_times
+        )
+    del tests
+
     serial_tests = [
         test
-        for test in tests
-        if get_duration(test, test_file_times, test_class_times) is None
-        or must_serial(test.test_file)
+        for test in pytest_sharded_tests
+        if must_serial(test.name) or test.time is None
     ]
-    parallel_tests = [test for test in tests if test not in serial_tests]
-
-    serial_time = sum(
-        get_duration(test, test_file_times, test_class_times) or 0
-        for test in serial_tests
-    )
-    parallel_time = sum(
-        get_duration(test, test_file_times, test_class_times) or 0
-        for test in parallel_tests
-    )
+    parallel_tests = [test for test in pytest_sharded_tests if test not in serial_tests]
+
+    serial_time = sum(test.get_time(DEFAULT_TIME) for test in serial_tests)
+    parallel_time = sum(test.get_time(DEFAULT_TIME) for test in parallel_tests)
     total_time = serial_time + parallel_time / NUM_PROCS_FOR_SHARDING_CALC
     estimated_time_per_shard = total_time / num_shards
     # Separate serial tests from parallel tests as much as possible to maximize
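As the hunk above shows, with sort_by_time=True the tests with known times are pytest-sharded and sorted longest-first, and tests with unknown times are appended after them; with sort_by_time=False the incoming order is kept. A rough standalone sketch of that ordering rule, using (name, time) tuples instead of TestRun/ShardedTest:

from typing import List, Optional, Tuple

def order_tests(tests: List[Tuple[str, Optional[float]]], sort_by_time: bool) -> List[str]:
    if not sort_by_time:
        return [name for name, _ in tests]
    known = [t for t in tests if t[1] is not None]
    unknown = [t for t in tests if t[1] is None]
    known.sort(key=lambda t: t[1], reverse=True)  # longest known tests first
    return [name for name, _ in known + unknown]

tests = [("t1", None), ("t2", 400.0), ("t3", 300.0), ("t4", None)]
print(order_tests(tests, sort_by_time=True))   # ['t2', 't3', 't1', 't4']
print(order_tests(tests, sort_by_time=False))  # ['t1', 't2', 't3', 't4']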
@@ -237,20 +236,14 @@ def calculate_shards(
 
     sharded_jobs = [ShardJob() for _ in range(num_shards)]
     shard(
-        sharded_jobs[:num_serial_shards],
-        serial_tests,
-        test_file_times,
-        test_class_times,
+        sharded_jobs=sharded_jobs[:num_serial_shards],
+        pytest_sharded_tests=serial_tests,
         estimated_time_limit=estimated_time_limit,
-        sort_by_time=sort_by_time,
         serial=True,
     )
     shard(
-        sharded_jobs,
-        parallel_tests,
-        test_file_times,
-        test_class_times,
-        sort_by_time=sort_by_time,
+        sharded_jobs=sharded_jobs,
+        pytest_sharded_tests=parallel_tests,
         serial=False,
     )
 