Compare commits

...

2 Commits

2 changed files with 44 additions and 1 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import argparse
import os
import re
import sys
import xml.etree.ElementTree as ET
from multiprocessing import cpu_count, Pool
@ -19,6 +20,19 @@ from tools.stats.upload_stats_lib import (
)
def should_upload_full_test_run(head_branch: str | None, head_repository: str) -> bool:
"""Return True if we should upload the full test_run dataset.
Rules:
- Only for the main repository (pytorch/pytorch)
- If head_branch is 'main', or a tag of form 'trunk/{40-hex-sha}'
"""
is_trunk_tag = bool(re.fullmatch(r"trunk/[0-9a-fA-F]{40}", (head_branch or "")))
return head_repository == "pytorch/pytorch" and (
head_branch == "main" or is_trunk_tag
)
def parse_xml_report(
tag: str,
report: Path,
@ -287,7 +301,8 @@ if __name__ == "__main__":
remove_nan_inf(failed_tests_cases),
)
if args.head_branch == "main" and args.head_repository == "pytorch/pytorch":
# Upload full test_run only for trusted refs (main or trunk/{sha} tags)
if should_upload_full_test_run(args.head_branch, args.head_repository):
# For jobs on main branch, upload everything.
upload_workflow_stats_to_s3(
args.workflow_run_id,

View File

@ -0,0 +1,28 @@
import unittest
from tools.stats.upload_test_stats import should_upload_full_test_run
class TestUploadGate(unittest.TestCase):
def test_main_branch_on_pytorch_repo(self) -> None:
self.assertTrue(should_upload_full_test_run("main", "pytorch/pytorch"))
def test_trunk_tag_valid_sha_on_pytorch_repo(self) -> None:
sha = "a" * 40
self.assertTrue(should_upload_full_test_run(f"trunk/{sha}", "pytorch/pytorch"))
def test_trunk_tag_invalid_sha_on_pytorch_repo(self) -> None:
# Not 40 hex chars
self.assertFalse(should_upload_full_test_run("trunk/12345", "pytorch/pytorch"))
def test_non_main_branch_on_pytorch_repo(self) -> None:
self.assertFalse(
should_upload_full_test_run("feature-branch", "pytorch/pytorch")
)
def test_main_branch_on_fork_repo(self) -> None:
self.assertFalse(should_upload_full_test_run("main", "someone/fork"))
if __name__ == "__main__":
unittest.main()