mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[test stats] use published test stats for sharding (#81116)
Use the nightly-published test stats to perform sharding, instead of calculating it in every build job. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81116 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
fb93c3988a
commit
9f58d5d7ce
2
.github/workflows/_linux-build.yml
vendored
2
.github/workflows/_linux-build.yml
vendored
@ -135,7 +135,7 @@ jobs:
|
||||
- name: Archive artifacts into zip
|
||||
if: inputs.build-generates-artifacts
|
||||
run: |
|
||||
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
|
||||
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin
|
||||
|
||||
- name: Store PyTorch Build Artifacts on S3
|
||||
uses: seemethere/upload-artifact-s3@v5
|
||||
|
@ -296,10 +296,4 @@ else
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
|
||||
# export test times so that potential sharded tests that'll branch off this build will use consistent data
|
||||
# don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
|
||||
python test/run_test.py --export-past-test-times
|
||||
fi
|
||||
|
||||
print_sccache_stats
|
||||
|
@ -146,9 +146,6 @@ python setup.py install --cmake && sccache --show-stats && (
|
||||
if errorlevel 1 exit /b
|
||||
if not errorlevel 0 exit /b
|
||||
|
||||
:: export test times so that potential sharded tests that'll branch off this build will use consistent data
|
||||
python test/run_test.py --export-past-test-times %PYTORCH_FINAL_PACKAGE_DIR%/.pytorch-test-times.json
|
||||
|
||||
:: Also save build/.ninja_log as an artifact
|
||||
copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\"
|
||||
)
|
||||
|
@ -1,7 +1,6 @@
|
||||
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
|
||||
|
||||
echo Copying over test times file
|
||||
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"
|
||||
|
||||
pushd test
|
||||
|
||||
|
@ -21,9 +21,6 @@ if "%SHARD_NUMBER%" == "1" (
|
||||
)
|
||||
)
|
||||
|
||||
echo Copying over test times file
|
||||
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"
|
||||
|
||||
echo Run nn tests
|
||||
python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
|
||||
if ERRORLEVEL 1 goto fail
|
||||
|
@ -32,11 +32,11 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
try:
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
from tools.stats.import_test_stats import get_test_times
|
||||
from tools.testing.test_selections import (
|
||||
export_S3_test_times,
|
||||
get_shard_based_on_S3,
|
||||
get_reordered_tests,
|
||||
get_test_case_configs,
|
||||
calculate_shards,
|
||||
)
|
||||
HAVE_TEST_SELECTION_TOOLS = True
|
||||
except ImportError:
|
||||
@ -677,13 +677,6 @@ def parse_args():
|
||||
help="additional arguments passed through to unittest, e.g., "
|
||||
"python run_test.py -i sparse -- TestSparse.test_factory_size_check",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-past-test-times",
|
||||
nargs="?",
|
||||
type=str,
|
||||
const=TEST_TIMES_FILE,
|
||||
help="dumps test times from previous S3 stats into a file, format JSON",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard",
|
||||
nargs=2,
|
||||
@ -838,11 +831,21 @@ def get_selected_tests(options):
|
||||
assert num_shards <= len(
|
||||
selected_tests
|
||||
), f"Number of shards must be less than {len(selected_tests)}"
|
||||
# TODO: fix this to use test_times_filename, but currently this is not working
|
||||
# because setting the export arg immeidately halts the test execution.
|
||||
selected_tests = get_shard_based_on_S3(
|
||||
which_shard, num_shards, selected_tests, TEST_TIMES_FILE
|
||||
|
||||
if num_shards == 1:
|
||||
return selected_tests
|
||||
|
||||
# Download previous test times to make sharding decisions
|
||||
test_file_times = get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE)
|
||||
if len(test_file_times) == 0:
|
||||
print(
|
||||
"::warning:: Gathered no stats from S3. Proceeding with default sharding plan."
|
||||
)
|
||||
selected_tests = selected_tests[which_shard - 1 :: num_shards]
|
||||
else:
|
||||
shards = calculate_shards(num_shards, selected_tests, test_file_times)
|
||||
_, tests_from_shard = shards[which_shard - 1]
|
||||
selected_tests = tests_from_shard
|
||||
|
||||
# skip all distributed tests if distributed package is not available.
|
||||
if not dist.is_available():
|
||||
@ -882,15 +885,6 @@ def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
|
||||
def main():
|
||||
options = parse_args()
|
||||
|
||||
# TODO: move this export & download function in tools/ folder
|
||||
test_times_filename = options.export_past_test_times
|
||||
if test_times_filename:
|
||||
print(
|
||||
f"Exporting past test times from S3 to {test_times_filename}, no tests will be run."
|
||||
)
|
||||
export_S3_test_times(test_times_filename)
|
||||
return
|
||||
|
||||
test_directory = str(REPO_ROOT / "test")
|
||||
selected_tests = get_selected_tests(options)
|
||||
|
||||
|
@ -41,6 +41,7 @@ def fetch_and_cache(
|
||||
This fetch and cache utils allows sharing between different process.
|
||||
"""
|
||||
path = os.path.join(dirpath, name)
|
||||
print(f"Downloading {url} to {path}")
|
||||
|
||||
def is_cached_file_valid() -> bool:
|
||||
# Check if the file is new enough (see: FILE_CACHE_LIFESPAN_SECONDS). A real check
|
||||
@ -80,6 +81,21 @@ def get_slow_tests(
|
||||
return {}
|
||||
|
||||
|
||||
def get_test_times(dirpath: str, filename: str) -> Dict[str, float]:
|
||||
url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-times.json"
|
||||
|
||||
def process_response(the_response: Dict[str, Any]) -> Any:
|
||||
build_environment = os.environ["BUILD_ENVIRONMENT"]
|
||||
test_config = os.environ["TEST_CONFIG"]
|
||||
return the_response[build_environment][test_config]
|
||||
|
||||
try:
|
||||
return fetch_and_cache(dirpath, filename, url, process_response)
|
||||
except Exception:
|
||||
print("Couldn't download test times...")
|
||||
return {}
|
||||
|
||||
|
||||
def get_disabled_tests(
|
||||
dirpath: str, filename: str = DISABLED_TESTS_FILE
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
|
Reference in New Issue
Block a user