mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375 Approved by: https://github.com/malfet
		
			
				
	
	
		
			108 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			108 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import argparse
 | |
| import os
 | |
| from typing import Any
 | |
| 
 | |
| 
 | |
| try:
 | |
|     from junitparser import (  # type: ignore[import]
 | |
|         Error,
 | |
|         Failure,
 | |
|         JUnitXml,
 | |
|         TestCase,
 | |
|         TestSuite,
 | |
|     )
 | |
| except ImportError as e:
 | |
|     raise ImportError(
 | |
|         "junitparser not found, please install with 'pip install junitparser'"
 | |
|     ) from e
 | |
| 
 | |
try:
    import rich
except ImportError:
    import types

    print("rich not found, for color output use 'pip install rich'")
    # The rest of this file calls ``rich.print`` unconditionally, so without a
    # binding for ``rich`` a missing package would surface later as a NameError.
    # Bind a minimal stand-in that forwards to the builtin ``print``; markup
    # tags and emoji codes are then shown verbatim instead of being rendered.
    rich = types.SimpleNamespace(print=print)  # type: ignore[assignment]
 | |
| 
 | |
| 
 | |
def parse_junit_reports(path_to_reports: str) -> list[TestCase]:  # type: ignore[no-any-unimported]
    """Collect the test cases from a JUnit XML report file or directory tree.

    Args:
        path_to_reports: Path to a single report file, or to a directory that
            is walked recursively for files with the ``.xml`` extension.

    Returns:
        The test cases from every report that could be read. A report that
        fails to load is announced with a warning and contributes nothing.

    Raises:
        FileNotFoundError: If ``path_to_reports`` does not exist.
    """

    def parse_file(path: str) -> list[TestCase]:  # type: ignore[no-any-unimported]
        # Best effort: one unreadable report must not abort the whole run, so
        # any failure is downgraded to a warning and an empty result.
        try:
            return convert_junit_to_testcases(JUnitXml.fromfile(path))
        except Exception as err:
            rich.print(
                f":Warning: [yellow]Warning[/yellow]: Failed to read {path}: {err}"
            )
            return []

    if not os.path.exists(path_to_reports):
        raise FileNotFoundError(f"Path '{path_to_reports}', not found")
    # Return early if the path provided is just a file
    if os.path.isfile(path_to_reports):
        return parse_file(path_to_reports)
    ret_xml = []
    if os.path.isdir(path_to_reports):
        for root, _, files in os.walk(path_to_reports):
            for fname in files:
                # Match the extension including its dot; a bare "xml" suffix
                # would also pick up unrelated names such as "dataxml".
                if fname.endswith(".xml"):
                    ret_xml += parse_file(os.path.join(root, fname))
    return ret_xml
 | |
| 
 | |
| 
 | |
def convert_junit_to_testcases(xml: JUnitXml | TestSuite) -> list[TestCase]:  # type: ignore[no-any-unimported]
    """Flatten a JUnit document or suite into a single list of its leaf items.

    Nested ``TestSuite`` entries are expanded in place, depth-first, so the
    returned list keeps the order in which the items are iterated.
    """

    def walk(node):  # type: ignore[no-untyped-def]
        # Yield every non-suite item below ``node``, descending into suites.
        for child in node:
            if isinstance(child, TestSuite):
                yield from walk(child)
            else:
                yield child

    return list(walk(xml))
 | |
| 
 | |
| 
 | |
def render_tests(testcases: list[TestCase]) -> None:  # type: ignore[no-any-unimported]
    """Print each failed or errored test, followed by pass/skip/fail totals.

    A test case with no result entries counts as passed. Otherwise every
    result entry is tallied on its own: ``Error`` and ``Failure`` entries count
    as failed and are printed with their text, any other entry counts as
    skipped and is not printed.

    Args:
        testcases: Test cases to summarise.
    """
    num_passed = 0
    num_skipped = 0
    num_failed = 0
    for testcase in testcases:
        if not testcase.result:
            num_passed += 1
            continue
        for result in testcase.result:
            if isinstance(result, Error):
                icon = ":rotating_light: [white on red]ERROR[/white on red]:"
                num_failed += 1
            elif isinstance(result, Failure):
                icon = ":x: [white on red]Failure[/white on red]:"
                num_failed += 1
            else:
                num_skipped += 1
                continue
            rich.print(
                f"{icon} [bold red]{testcase.classname}.{testcase.name}[/bold red]"
            )
            # The result text goes through the builtin print, so it is not
            # interpreted as markup. Skip it when absent rather than printing
            # the literal string "None".
            if result.text:
                print(result.text)
    # Each style tag is closed with its matching "[/...]" tag.
    rich.print(f":white_check_mark: {num_passed} [green]Passed[/green]")
    rich.print(f":dash: {num_skipped} [grey]Skipped[/grey]")
    rich.print(f":rotating_light: {num_failed} [grey]Failed[/grey]")
 | |
| 
 | |
| 
 | |
def parse_args() -> Any:
    """Parse the command line and return the resulting namespace.

    The command line takes one positional argument, ``report_path``.
    """
    cli = argparse.ArgumentParser(description="Render xunit output for failed tests")
    cli.add_argument(
        "report_path",
        help="Base xunit reports (single file or directory) to compare to",
    )
    namespace = cli.parse_args()
    return namespace
 | |
| 
 | |
| 
 | |
def main() -> None:
    """Read the report path from the command line, load the reports, render them."""
    report_path = parse_args().report_path
    render_tests(parse_junit_reports(report_path))


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
 |