mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow sharding for distributed tests
Addresses my mistake introduced in https://github.com/pytorch/pytorch/pull/76536#issuecomment-1112657429 Also allows for sharding 1 in run_test.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/76570 Approved by: https://github.com/jeffdaily, https://github.com/seemethere
This commit is contained in:
committed by
PyTorch MergeBot
parent
fe1968dea0
commit
0708630d9f
@ -319,10 +319,10 @@ test_vulkan() {
|
|||||||
|
|
||||||
test_distributed() {
|
test_distributed() {
|
||||||
echo "Testing distributed python tests"
|
echo "Testing distributed python tests"
|
||||||
time python test/run_test.py --distributed-tests --verbose
|
time python test/run_test.py --distributed-tests --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose
|
||||||
assert_git_not_dirty
|
assert_git_not_dirty
|
||||||
|
|
||||||
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
|
if [[ "$BUILD_ENVIRONMENT" == *cuda* && "$SHARD_NUMBER" == 1 ]]; then
|
||||||
echo "Testing distributed C++ tests"
|
echo "Testing distributed C++ tests"
|
||||||
ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR"
|
ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR"
|
||||||
ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR"
|
ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR"
|
||||||
@ -564,6 +564,12 @@ elif [[ "${BUILD_ENVIRONMENT}" == *jit_legacy-test || "${JOB_BASE_NAME}" == *jit
|
|||||||
elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
|
elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
|
||||||
# TODO: run some C++ tests
|
# TODO: run some C++ tests
|
||||||
echo "no-op at the moment"
|
echo "no-op at the moment"
|
||||||
|
elif [[ "${BUILD_ENVIRONMENT}" == *distributed* || "${JOB_BASE_NAME}" == *distributed* ]]; then
|
||||||
|
test_distributed
|
||||||
|
# Only run RPC C++ tests on the first shard
|
||||||
|
if [[ "${SHARD_NUMBER}" == 1 ]]; then
|
||||||
|
test_rpc
|
||||||
|
fi
|
||||||
elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
|
elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
|
||||||
test_without_numpy
|
test_without_numpy
|
||||||
install_torchvision
|
install_torchvision
|
||||||
@ -585,9 +591,6 @@ elif [[ "${BUILD_ENVIRONMENT}" == *vulkan* ]]; then
|
|||||||
echo "no-op at the moment"
|
echo "no-op at the moment"
|
||||||
elif [[ "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
|
elif [[ "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
|
||||||
test_bazel
|
test_bazel
|
||||||
elif [[ "${BUILD_ENVIRONMENT}" == *distributed* || "${JOB_BASE_NAME}" == *distributed* ]]; then
|
|
||||||
test_distributed
|
|
||||||
test_rpc
|
|
||||||
elif [[ "${BUILD_ENVIRONMENT}" == *-mobile-lightweight-dispatch* ]]; then
|
elif [[ "${BUILD_ENVIRONMENT}" == *-mobile-lightweight-dispatch* ]]; then
|
||||||
test_libtorch
|
test_libtorch
|
||||||
elif [[ "${TEST_CONFIG}" = docs_test ]]; then
|
elif [[ "${TEST_CONFIG}" = docs_test ]]; then
|
||||||
|
@ -65,6 +65,29 @@ class TestCalculateShards(unittest.TestCase):
|
|||||||
expected_shards, calculate_shards(2, self.tests, self.test_times)
|
expected_shards, calculate_shards(2, self.tests, self.test_times)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_calculate_1_shard_with_complete_test_times(self) -> None:
|
||||||
|
expected_shards = [
|
||||||
|
(
|
||||||
|
118.31,
|
||||||
|
[
|
||||||
|
"super_long_test",
|
||||||
|
"long_test1",
|
||||||
|
"long_test2",
|
||||||
|
"normal_test1",
|
||||||
|
"normal_test2",
|
||||||
|
"normal_test3",
|
||||||
|
"short_test1",
|
||||||
|
"short_test2",
|
||||||
|
"short_test3",
|
||||||
|
"short_test4",
|
||||||
|
"short_test5",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
self.assert_shards_equal(
|
||||||
|
expected_shards, calculate_shards(1, self.tests, self.test_times)
|
||||||
|
)
|
||||||
|
|
||||||
def test_calculate_5_shards_with_complete_test_times(self) -> None:
|
def test_calculate_5_shards_with_complete_test_times(self) -> None:
|
||||||
expected_shards = [
|
expected_shards = [
|
||||||
(55.0, ["super_long_test"]),
|
(55.0, ["super_long_test"]),
|
||||||
|
@ -194,10 +194,14 @@ def _query_changed_test_files() -> List[str]:
|
|||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
# Get sharded test allocation based on historic S3 data.
|
||||||
def get_shard_based_on_S3(
|
def get_shard_based_on_S3(
|
||||||
which_shard: int, num_shards: int, tests: List[str], test_times_file: str
|
which_shard: int, num_shards: int, tests: List[str], test_times_file: str
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Get sharded test allocation based on historic S3 data."""
|
# Short circuit and don't do any work if there's only 1 shard
|
||||||
|
if num_shards == 1:
|
||||||
|
return tests
|
||||||
|
|
||||||
jobs_to_times = _query_past_job_times(test_times_file)
|
jobs_to_times = _query_past_job_times(test_times_file)
|
||||||
|
|
||||||
# Got no stats from S3, returning early to save runtime
|
# Got no stats from S3, returning early to save runtime
|
||||||
|
Reference in New Issue
Block a user