[BE][Easy] enable postponed annotations in tools (#129375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan
2024-06-28 16:28:16 +08:00
committed by PyTorch MergeBot
parent 2e3ff394bf
commit 59eb2897f1
123 changed files with 1274 additions and 1051 deletions

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import json
import os
import time
from typing import Any, Dict, List, Set, Tuple
from typing import Any, TYPE_CHECKING
from ..util.setting import (
CompilerType,
@ -16,7 +18,6 @@ from ..util.utils import (
print_time,
related_to_test_list,
)
from .parser.coverage_record import CoverageRecord
from .parser.gcov_coverage_parser import GcovCoverageParser
from .parser.llvm_coverage_parser import LlvmCoverageParser
from .print_report import (
@ -26,16 +27,20 @@ from .print_report import (
)
if TYPE_CHECKING:
from .parser.coverage_record import CoverageRecord
# coverage_records: Dict[str, LineInfo] = {}
covered_lines: Dict[str, Set[int]] = {}
uncovered_lines: Dict[str, Set[int]] = {}
covered_lines: dict[str, set[int]] = {}
uncovered_lines: dict[str, set[int]] = {}
tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()}
def transform_file_name(
file_path: str, interested_folders: List[str], platform: TestPlatform
file_path: str, interested_folders: list[str], platform: TestPlatform
) -> str:
remove_patterns: Set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
remove_patterns: set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
for pattern in remove_patterns:
file_path = file_path.replace(pattern, "")
# if user has specified interested folder
@ -54,7 +59,7 @@ def transform_file_name(
def is_intrested_file(
file_path: str, interested_folders: List[str], platform: TestPlatform
file_path: str, interested_folders: list[str], platform: TestPlatform
) -> bool:
ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"]
if any(pattern in file_path for pattern in ignored_patterns):
@ -77,7 +82,7 @@ def is_intrested_file(
return True
def get_json_obj(json_file: str) -> Tuple[Any, int]:
def get_json_obj(json_file: str) -> tuple[Any, int]:
"""
Sometimes at the start of file llvm/gcov will complains "fail to find coverage data",
then we need to skip these lines
@ -102,7 +107,7 @@ def get_json_obj(json_file: str) -> Tuple[Any, int]:
return None, 2
def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
def parse_json(json_file: str, platform: TestPlatform) -> list[CoverageRecord]:
print("start parse:", json_file)
json_obj, read_status = get_json_obj(json_file)
if read_status == 0:
@ -117,7 +122,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
cov_type = detect_compiler_type(platform)
coverage_records: List[CoverageRecord] = []
coverage_records: list[CoverageRecord] = []
if cov_type == CompilerType.CLANG:
coverage_records = LlvmCoverageParser(json_obj).parse("fbcode")
# print(coverage_records)
@ -128,7 +133,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
def parse_jsons(
test_list: TestList, interested_folders: List[str], platform: TestPlatform
test_list: TestList, interested_folders: list[str], platform: TestPlatform
) -> None:
g = os.walk(JSON_FOLDER_BASE_DIR)
@ -152,8 +157,8 @@ def parse_jsons(
def update_coverage(
coverage_records: List[CoverageRecord],
interested_folders: List[str],
coverage_records: list[CoverageRecord],
interested_folders: list[str],
platform: TestPlatform,
) -> None:
for item in coverage_records:
@ -187,8 +192,8 @@ def update_set() -> None:
def summarize_jsons(
test_list: TestList,
interested_folders: List[str],
coverage_only: List[str],
interested_folders: list[str],
coverage_only: list[str],
platform: TestPlatform,
) -> None:
start_time = time.time()