[BE][GHF] Add retries_decorator (#101227)
I've noticed that 3-4 functions in trymerge are trying to implement similar tail recursion for flaky network retries. Unify them using a single wrapper in `gitutils.py`.

### <samp>🤖 Generated by Copilot at 8d40631</samp>

> _`retries_decorator`_
> _adds resilience to GitHub scripts_
> _autumn of errors_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101227
Approved by: https://github.com/kit1980
Committed by: PyTorch MergeBot
Parent: 2fcc2002fa
Commit: 568bac7961
.github/scripts/gitutils.py (vendored, 36 lines changed)
```diff
@@ -5,8 +5,21 @@ import re
 import tempfile
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
+T = TypeVar("T")
 
 RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$")
 
@@ -380,3 +393,24 @@ def are_ghstack_branches_in_sync(repo: GitRepo, head_ref: str) -> bool:
         repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}")
     )
     return orig_diff_sha == head_diff_sha
+
+
+def retries_decorator(
+    rc: Any = None, num_retries: int = 3
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    def decorator(f: Callable[..., T]) -> Callable[..., T]:
+        @wraps(f)
+        def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T:
+            for idx in range(num_retries):
+                try:
+                    return f(*args, **kwargs)
+                except Exception as e:
+                    print(
+                        f'Attempt {idx} of {num_retries} to call {f.__name__} failed with "{e}"'
+                    )
+                    pass
+            return cast(T, rc)
+
+        return wrapper
+
+    return decorator
```
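To make the decorator's contract concrete, here is a minimal usage sketch (illustrative only; `flaky_fetch`, `always_fails`, and the attempt counter are made up, not part of this PR). A decorated call is attempted up to `num_retries` times; once every attempt has raised, the wrapper returns the fallback `rc` instead of re-raising:

```python
# Illustrative sketch, not part of the PR: exercising retries_decorator.
from gitutils import retries_decorator

attempts = {"n": 0}

@retries_decorator(rc=[], num_retries=3)
def flaky_fetch() -> list:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient network error")  # simulated flakiness
    return ["ok"]

print(flaky_fetch())  # attempts 0 and 1 log failures; attempt 2 returns ['ok']

@retries_decorator(rc=None, num_retries=3)
def always_fails() -> None:
    raise RuntimeError("permanent failure")

print(always_fails())  # all 3 attempts fail, so the fallback rc (None) comes back
```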
.github/scripts/test_gitutils.py (vendored, 17 lines changed)
```diff
@@ -8,6 +8,7 @@ from gitutils import (
     GitRepo,
     patterns_to_regex,
     PeekableIterator,
+    retries_decorator,
 )
 
 
@@ -49,6 +50,22 @@ class TestPattern(TestCase):
             self.assertTrue(patterns_re.match(filename))
 
 
+class TestRetriesDecorator(TestCase):
+    def test_simple(self) -> None:
+        @retries_decorator()
+        def foo(x: int, y: int) -> int:
+            return x + y
+
+        self.assertEqual(foo(3, 4), 7)
+
+    def test_fails(self) -> None:
+        @retries_decorator(rc=0)
+        def foo(x: int, y: int) -> int:
+            return x + y
+
+        self.assertEqual(foo("a", 4), 0)
+
+
 class TestGitRepo(TestCase):
     def setUp(self) -> None:
         repo_dir = BASE_DIR.parent.parent.absolute()
```
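A gloss on `test_fails` (my reading; the PR itself doesn't spell this out): `"a" + 4` raises `TypeError` on every attempt, the wrapper logs each of the three failures, and then returns the fallback `rc=0`, which is what the assertion checks:

```python
# Illustrative only: the exception that drives test_fails.
try:
    "a" + 4  # type: ignore[operator]
except TypeError as e:
    # Each of the 3 attempts prints: Attempt i of 3 to call foo failed with "..."
    print(e)  # can only concatenate str (not "int") to str
# After the attempts are exhausted, the wrapper returns cast(T, rc), i.e. 0.
```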
.github/scripts/trymerge.py (vendored, 51 lines changed)
```diff
@@ -38,6 +38,7 @@ from gitutils import (
     get_git_repo_dir,
     GitRepo,
     patterns_to_regex,
+    retries_decorator,
 )
 from label_utils import (
     gh_add_labels,
@@ -1386,16 +1387,12 @@ def checks_to_markdown_bullets(
     ]
 
 
-def _get_flaky_rules(url: str, num_retries: int = 3) -> List[FlakyRule]:
-    try:
-        return [FlakyRule(**rule) for rule in gh_fetch_json_list(url)]
-    except Exception as e:
-        print(f"Could not download {url} because: {e}.")
-        if num_retries > 0:
-            return _get_flaky_rules(url, num_retries=num_retries - 1)
-    return []
+@retries_decorator(rc=[])
+def _get_flaky_rules(url: str) -> List[FlakyRule]:
+    return [FlakyRule(**rule) for rule in gh_fetch_json_list(url)]
 
 
+@retries_decorator()
 def save_merge_record(
     collection: str,
     comment_id: int,
@@ -1414,7 +1411,6 @@ def save_merge_record(
     ignore_current: bool = False,
     error: str = "",
     workspace: str = "commons",
-    num_retries: int = 3,
 ) -> None:
     """
     This saves the merge records into Rockset, so we can query them (for fun and profit)
@@ -1460,35 +1456,9 @@ def save_merge_record(
         print("Rockset is missing, no record will be saved")
         return
-
-    except Exception as e:
-        if num_retries > 0:
-            print(f"Could not upload to Rockset ({num_retries - 1} tries left): {e}")
-            return save_merge_record(
-                collection=collection,
-                comment_id=comment_id,
-                pr_num=pr_num,
-                owner=owner,
-                project=project,
-                author=author,
-                pending_checks=pending_checks,
-                failed_checks=failed_checks,
-                last_commit_sha=last_commit_sha,
-                merge_base_sha=merge_base_sha,
-                merge_commit_sha=merge_commit_sha,
-                is_failed=is_failed,
-                dry_run=dry_run,
-                skip_mandatory_checks=skip_mandatory_checks,
-                ignore_current=ignore_current,
-                error=error,
-                workspace=workspace,
-                num_retries=num_retries - 1,
-            )
-        print(f"Could not upload to Rockset ({num_retries} tries left): {e}")
 
 
-def get_rockset_results(
-    head_sha: str, merge_base: str, num_retries: int = 3
-) -> List[Dict[str, Any]]:
+@retries_decorator(rc=[])
+def get_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
     query = f"""
     SELECT
         w.name as workflow_name,
@@ -1515,13 +1485,6 @@ where
     except ModuleNotFoundError:
         print("Could not use RockSet as rocket dependency is missing")
         return []
-    except Exception as e:
-        print(f"Could not download rockset data because: {e}.")
-        if num_retries > 0:
-            return get_rockset_results(
-                head_sha, merge_base, num_retries=num_retries - 1
-            )
-        return []
 
 
 def get_classifications(
```
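The shape of the refactor is the same at all three call sites: the hand-rolled `num_retries` parameter and the recursive self-call disappear, and persistent failure surfaces as the decorator's fallback `rc`. A schematic before/after mirroring `_get_flaky_rules` (the `download` helper is a stub for illustration, not from the PR):

```python
from typing import List
from gitutils import retries_decorator

def download(url: str) -> List[str]:
    raise RuntimeError("simulated network failure")  # stub for illustration

# Before: each function re-implements retry-by-recursion.
def fetch_before(url: str, num_retries: int = 3) -> List[str]:
    try:
        return download(url)
    except Exception as e:
        print(f"Could not download {url} because: {e}.")
        if num_retries > 0:
            return fetch_before(url, num_retries=num_retries - 1)
    return []

# After: the retry policy lives in one place, in gitutils.
@retries_decorator(rc=[])
def fetch_after(url: str) -> List[str]:
    return download(url)

assert fetch_before("https://example.invalid") == []
assert fetch_after("https://example.invalid") == []
```

One behavioral consequence worth noting: the wrapper catches bare `Exception`, so callers of the decorated functions see the fallback value (`[]`, `None`, etc.) on persistent failure rather than an exception.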