mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Allow sharding for distributed tests
Addresses my mistake introduced in https://github.com/pytorch/pytorch/pull/76536#issuecomment-1112657429
Also allows run_test.py to be run with a single shard (a shard count of 1).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76570
Approved by: https://github.com/jeffdaily, https://github.com/seemethere
This commit is contained in:
committed by
PyTorch MergeBot
parent
fe1968dea0
commit
0708630d9f
@ -319,10 +319,10 @@ test_vulkan() {
|
||||
|
||||
test_distributed() {
|
||||
echo "Testing distributed python tests"
|
||||
time python test/run_test.py --distributed-tests --verbose
|
||||
time python test/run_test.py --distributed-tests --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose
|
||||
assert_git_not_dirty
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
|
||||
if [[ "$BUILD_ENVIRONMENT" == *cuda* && "$SHARD_NUMBER" == 1 ]]; then
|
||||
echo "Testing distributed C++ tests"
|
||||
ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR"
|
||||
ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR"
|
||||
@ -564,6 +564,12 @@ elif [[ "${BUILD_ENVIRONMENT}" == *jit_legacy-test || "${JOB_BASE_NAME}" == *jit
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
|
||||
# TODO: run some C++ tests
|
||||
echo "no-op at the moment"
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *distributed* || "${JOB_BASE_NAME}" == *distributed* ]]; then
|
||||
test_distributed
|
||||
# Only run RPC C++ tests on the first shard
|
||||
if [[ "${SHARD_NUMBER}" == 1 ]]; then
|
||||
test_rpc
|
||||
fi
|
||||
elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
|
||||
test_without_numpy
|
||||
install_torchvision
|
||||
@ -585,9 +591,6 @@ elif [[ "${BUILD_ENVIRONMENT}" == *vulkan* ]]; then
|
||||
echo "no-op at the moment"
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
|
||||
test_bazel
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *distributed* || "${JOB_BASE_NAME}" == *distributed* ]]; then
|
||||
test_distributed
|
||||
test_rpc
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *-mobile-lightweight-dispatch* ]]; then
|
||||
test_libtorch
|
||||
elif [[ "${TEST_CONFIG}" = docs_test ]]; then
|
||||
|
@ -65,6 +65,29 @@ class TestCalculateShards(unittest.TestCase):
|
||||
expected_shards, calculate_shards(2, self.tests, self.test_times)
|
||||
)
|
||||
|
||||
def test_calculate_1_shard_with_complete_test_times(self) -> None:
|
||||
expected_shards = [
|
||||
(
|
||||
118.31,
|
||||
[
|
||||
"super_long_test",
|
||||
"long_test1",
|
||||
"long_test2",
|
||||
"normal_test1",
|
||||
"normal_test2",
|
||||
"normal_test3",
|
||||
"short_test1",
|
||||
"short_test2",
|
||||
"short_test3",
|
||||
"short_test4",
|
||||
"short_test5",
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(1, self.tests, self.test_times)
|
||||
)
|
||||
|
||||
def test_calculate_5_shards_with_complete_test_times(self) -> None:
|
||||
expected_shards = [
|
||||
(55.0, ["super_long_test"]),
|
||||
|
@ -194,10 +194,14 @@ def _query_changed_test_files() -> List[str]:
|
||||
return lines
|
||||
|
||||
|
||||
# Get sharded test allocation based on historic S3 data.
|
||||
def get_shard_based_on_S3(
|
||||
which_shard: int, num_shards: int, tests: List[str], test_times_file: str
|
||||
) -> List[str]:
|
||||
"""Get sharded test allocation based on historic S3 data."""
|
||||
# Short circuit and don't do any work if there's only 1 shard
|
||||
if num_shards == 1:
|
||||
return tests
|
||||
|
||||
jobs_to_times = _query_past_job_times(test_times_file)
|
||||
|
||||
# Got no stats from S3, returning early to save runtime
|
||||
|
Reference in New Issue
Block a user