mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: This will make https://github.com/pytorch/test-infra/commits/master more readable in the future Pull Request resolved: https://github.com/pytorch/pytorch/pull/55862 Reviewed By: ngimel Differential Revision: D27728462 Pulled By: malfet fbshipit-source-id: 2f10dd7ace49f343c4b91fc02be9d955fdbf67cc
71 lines
2.7 KiB
Python
71 lines
2.7 KiB
Python
#!/usr/bin/env python
|
|
|
|
import argparse
import json
import os
import statistics
from collections import defaultdict
from typing import cast, DefaultDict, Dict, List, Optional

from tools.stats_utils.s3_stat_parser import get_previous_reports_for_branch, Report, Version2Report
|
|
|
|
# Default path of the JSON file that the slow test times are written to.
SLOW_TESTS_FILE = '.pytorch-slow-tests'
# Minimum mean duration, in seconds, for a test case to be reported as slow.
SLOW_TEST_CASE_THRESHOLD_SEC = 60.0
|
|
|
|
|
|
def get_test_case_times() -> Dict[str, float]:
    """Return the mean duration of each passing test case in previous reports.

    Reports are fetched for the 'origin/viable/strict' branch. Only test cases
    whose status is ``None`` (i.e. they succeeded) contribute to the mean.

    Returns:
        A mapping from a name such as
        "test_doc_examples (__main__.TestTypeHints)" to the mean of the
        recorded 'seconds' values for that test case.

    Raises:
        RuntimeError: If any fetched report is not in format version 2.
    """
    durations_by_test: DefaultDict[str, List[float]] = defaultdict(list)
    history: List[Report] = get_previous_reports_for_branch('origin/viable/strict', "")

    for entry in history:
        # Reports without an explicit 'format_version' are treated as version 1.
        if entry.get('format_version', 1) != 2:
            raise RuntimeError("S3 format currently handled is version 2 only")

        files = cast(Version2Report, entry)['files']
        for file_stats in files.values():
            for suite_name, suite_stats in file_stats['suites'].items():
                for case_name, case_stats in suite_stats['cases'].items():
                    # A non-None status marks a case that did not succeed; skip it.
                    if case_stats['status'] is not None:
                        continue
                    # The "__main__" prefix matches the format of test.__class__ in
                    # common_utils.py (where this data will be used) and what the
                    # output of a running test looks like.
                    key = f'{case_name} (__main__.{suite_name})'
                    durations_by_test[key].append(case_stats['seconds'])

    averages: Dict[str, float] = {}
    for key, samples in durations_by_test.items():
        averages[key] = statistics.mean(samples)
    return averages
|
|
|
|
|
|
def filter_slow_tests(test_cases_dict: Dict[str, float], threshold: Optional[float] = None) -> Dict[str, float]:
    """Keep only the test cases whose time reaches the slow-test threshold.

    Args:
        test_cases_dict: Mapping of test case name to its time in seconds.
        threshold: Minimum time, in seconds, for a test case to be kept
            (inclusive). Defaults to ``SLOW_TEST_CASE_THRESHOLD_SEC``.

    Returns:
        A new dict holding the entries of ``test_cases_dict`` whose time is
        greater than or equal to the threshold.
    """
    if threshold is None:
        # Resolved at call time so the module constant remains the default.
        threshold = SLOW_TEST_CASE_THRESHOLD_SEC
    return {test_case: time for test_case, time in test_cases_dict.items() if time >= threshold}
|
|
|
|
|
|
def export_slow_tests(filename: str) -> None:
    """Write the slow test cases and their mean times to ``filename`` as JSON.

    The data is gathered before the destination is opened, so a failure while
    collecting it leaves an already existing file untouched instead of
    truncated.

    Args:
        filename: Path of the JSON file to create or overwrite.
    """
    # Collect first: opening the file for writing truncates it immediately.
    slow_test_times: Dict[str, float] = filter_slow_tests(get_test_case_times())
    if os.path.exists(filename):
        print(f'Overwriting existent file: {filename}')
    # The file is only written, never read back, so plain 'w' is sufficient.
    with open(filename, 'w') as file:
        json.dump(slow_test_times, file, indent=' ', separators=(',', ': '), sort_keys=True)
        # Terminate the file with a newline; json.dump does not emit one.
        file.write('\n')
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of this script.

    Returns:
        The namespace produced by argparse; its ``filename`` attribute holds
        the output path, which is ``SLOW_TESTS_FILE`` both when the option is
        omitted and when it is given without a value.
    """
    arg_parser = argparse.ArgumentParser(description='Export a JSON of slow test cases in PyTorch unit test suite')
    # nargs='?' makes the value optional: `default` covers a missing flag and
    # `const` covers a flag passed without a value.
    filename_options = dict(
        nargs='?',
        type=str,
        default=SLOW_TESTS_FILE,
        const=SLOW_TESTS_FILE,
        help='Specify a file path to dump slow test times from previous S3 stats. Default file path: .pytorch-slow-tests',
    )
    arg_parser.add_argument('-f', '--filename', **filename_options)
    return arg_parser.parse_args()
|
|
|
|
|
|
def main():
    """Parse the command line and export the slow tests to the chosen file."""
    export_slow_tests(parse_args().filename)
|
|
|
|
|
|
# Run the export only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
|