[BE][Easy] enable postponed annotations in tools (#129375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet
Author: Xuehai Pan
Date: 2024-06-29 12:48:06 +08:00
Committed by: PyTorch MergeBot
Parent: 58f346c874
Commit: 8a67daf283
123 changed files with 1274 additions and 1053 deletions
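The change adds `from __future__ import annotations` so that annotations are evaluated lazily (PEP 563). With postponed evaluation, the tools can annotate with builtin generics such as `dict[str, str]` and the PEP 604 union form `int | None`, making the `typing.Dict`/`typing.List`/`typing.Optional` imports unnecessary even on interpreters where those newer annotation forms are not valid at runtime. A minimal sketch of the mechanism (illustrative only; the bodies below are stand-ins, not code from this PR):

# sketch.py -- illustrative, not part of the diff
from __future__ import annotations


def _get_request_headers() -> dict[str, str]:
    # With postponed evaluation the annotation is stored as a string and never
    # executed at import time, so dict[str, str] is safe on Python < 3.9.
    return {"Accept": "application/vnd.github.v3+json"}


def parse_job_id(name: str) -> int | None:
    # Hypothetical helper: the PEP 604 union in the annotation is likewise
    # left unevaluated, so it works on Python < 3.10.
    return int(name) if name.isdigit() else None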


@@ -1,16 +1,18 @@
+from __future__ import annotations
import gzip
import io
import json
import os
import zipfile
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any
import boto3 # type: ignore[import]
import requests
import rockset # type: ignore[import]
PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
S3_RESOURCE = boto3.resource("s3")
@@ -21,14 +23,14 @@ MAX_RETRY_IN_NON_DISABLED_MODE = 3 * 3
BATCH_SIZE = 5000
-def _get_request_headers() -> Dict[str, str]:
+def _get_request_headers() -> dict[str, str]:
return {
"Accept": "application/vnd.github.v3+json",
"Authorization": "token " + os.environ["GITHUB_TOKEN"],
}
-def _get_artifact_urls(prefix: str, workflow_run_id: int) -> Dict[Path, str]:
+def _get_artifact_urls(prefix: str, workflow_run_id: int) -> dict[Path, str]:
"""Get all workflow artifacts with 'test-report' in the name."""
response = requests.get(
f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100",
@@ -78,7 +80,7 @@ def _download_artifact(
def download_s3_artifacts(
prefix: str, workflow_run_id: int, workflow_run_attempt: int
-) -> List[Path]:
+) -> list[Path]:
bucket = S3_RESOURCE.Bucket("gha-artifacts")
objs = bucket.objects.filter(
Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}"
@@ -104,7 +106,7 @@ def download_s3_artifacts(
def download_gha_artifacts(
prefix: str, workflow_run_id: int, workflow_run_attempt: int
-) -> List[Path]:
+) -> list[Path]:
artifact_urls = _get_artifact_urls(prefix, workflow_run_id)
paths = []
for name, url in artifact_urls.items():
@@ -114,7 +116,7 @@ def download_gha_artifacts(
def upload_to_rockset(
collection: str,
-docs: List[Any],
+docs: list[Any],
workspace: str = "commons",
client: Any = None,
) -> None:
@@ -142,7 +144,7 @@ def upload_to_rockset(
def upload_to_s3(
bucket_name: str,
key: str,
-docs: List[Dict[str, Any]],
+docs: list[dict[str, Any]],
) -> None:
print(f"Writing {len(docs)} documents to S3")
body = io.StringIO()
@@ -164,7 +166,7 @@ def upload_to_s3(
def read_from_s3(
bucket_name: str,
key: str,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
print(f"Reading from s3://{bucket_name}/{key}")
body = (
S3_RESOURCE.Object(
@@ -182,7 +184,7 @@ def upload_workflow_stats_to_s3(
workflow_run_id: int,
workflow_run_attempt: int,
collection: str,
-docs: List[Dict[str, Any]],
+docs: list[dict[str, Any]],
) -> None:
bucket_name = "ossci-raw-job-status"
key = f"{collection}/{workflow_run_id}/{workflow_run_attempt}"
@@ -220,7 +222,7 @@ def unzip(p: Path) -> None:
zip.extractall(unzipped_dir)
-def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
+def is_rerun_disabled_tests(tests: dict[str, dict[str, int]]) -> bool:
"""
Check if the test report is coming from rerun_disabled_tests workflow where
each test is run multiple times
@@ -231,7 +233,7 @@ def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
)
-def get_job_id(report: Path) -> Optional[int]:
+def get_job_id(report: Path) -> int | None:
# [Job id in artifacts]
# Retrieve the job id from the report path. In our GHA workflows, we append
# the job id to the end of the report name, so `report` looks like: