[ez][TD] Increase logging (#124082)

increase logging during td
generate an artifact that says which tests got excluded
fix minor bug where filter test configs couldnt get commit messages

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124082
Approved by: https://github.com/seemethere, https://github.com/malfet
This commit is contained in:
Catherine Lee
2024-04-17 00:18:24 +00:00
committed by PyTorch MergeBot
parent e7cf6f81ea
commit 946b50c788
5 changed files with 23 additions and 1 deletions

View File

@ -449,7 +449,7 @@ def parse_reenabled_issues(s: Optional[str]) -> List[str]:
def get_reenabled_issues(pr_body: str = "") -> List[str]:
default_branch = os.getenv("GIT_DEFAULT_BRANCH", "main")
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
try:
commit_messages = subprocess.check_output(
f"git cherry -v {default_branch}".split(" ")

View File

@ -57,6 +57,7 @@ from tools.testing.discover_tests import (
TESTS,
)
from tools.testing.do_target_determination_for_s3 import import_results
from tools.testing.target_determination.gen_artifact import gen_ci_artifact
from tools.testing.test_run import TestRun
from tools.testing.test_selections import (
@ -1714,6 +1715,8 @@ def main():
test_batch = TestBatch("tests to run", include, False)
test_batch_exclude = TestBatch("excluded", exclude, True)
if IS_CI:
gen_ci_artifact([x.to_json() for x in include], [x.to_json() for x in exclude])
print_to_stderr(test_batch)
print_to_stderr(test_batch_exclude)

View File

@ -56,6 +56,9 @@ def main() -> None:
test_prioritizations = aggregated_heuristics.get_aggregated_priorities()
print("Aggregated Heuristics")
print(test_prioritizations.get_info_str(verbose=False))
if os.getenv("CI") == "true":
print("Emitting metrics")
# Split into 3 due to size constraints

View File

@ -0,0 +1,10 @@
import json
import pathlib
from typing import Any, List
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
def gen_ci_artifact(included: List[Any], excluded: List[Any]) -> None:
with open(REPO_ROOT / "test/test-reports/td_exclusions.json", "w") as f:
json.dump({"included": included, "excluded": excluded}, f)

View File

@ -2,6 +2,7 @@ import json
import os
import subprocess
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import cast, Dict, List, Set, Union
from urllib.request import Request, urlopen
@ -20,6 +21,7 @@ def python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
return valid_tests
@lru_cache(maxsize=None)
def query_changed_files() -> List[str]:
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
merge_base = (
@ -40,15 +42,18 @@ def query_changed_files() -> List[str]:
capture_output=True,
check=False,
)
print(f"merge_base: {merge_base}, head: {head}")
if proc.returncode != 0:
raise RuntimeError("Unable to get changed files")
lines = proc.stdout.decode().strip().split("\n")
lines = [line.strip() for line in lines]
print(f"Changed files: {lines}")
return lines
@lru_cache(maxsize=None)
def get_git_commit_info() -> str:
"""Gets the commit info since the last commit on the default branch."""
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
@ -75,6 +80,7 @@ def get_git_commit_info() -> str:
)
@lru_cache(maxsize=None)
def get_issue_or_pr_body(number: int) -> str:
"""Gets the body of an issue or PR"""
github_token = os.environ.get("GITHUB_TOKEN")