[BE][GHF] Add retries_decorator (#101227)

I've noticed that 3-4 functions in trymerge are trying to implement similar tail recursion for flaky network retries.

Unify them using single wrapper in `gitutils.py`

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 8d40631</samp>

> _`retries_decorator`_
> _adds resilience to GitHub scripts_
> _autumn of errors_
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101227
Approved by: https://github.com/kit1980
This commit is contained in:
Nikita Shulga
2023-05-12 20:29:06 +00:00
committed by PyTorch MergeBot
parent 2fcc2002fa
commit 568bac7961
3 changed files with 59 additions and 45 deletions

View File

@ -5,8 +5,21 @@ import re
import tempfile
from collections import defaultdict
from datetime import datetime
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
from functools import wraps
from typing import (
Any,
Callable,
cast,
Dict,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)
T = TypeVar("T")
RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$")
@ -380,3 +393,24 @@ def are_ghstack_branches_in_sync(repo: GitRepo, head_ref: str) -> bool:
repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}")
)
return orig_diff_sha == head_diff_sha
def retries_decorator(
rc: Any = None, num_retries: int = 3
) -> Callable[[Callable[..., T]], Callable[..., T]]:
def decorator(f: Callable[..., T]) -> Callable[..., T]:
@wraps(f)
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T:
for idx in range(num_retries):
try:
return f(*args, **kwargs)
except Exception as e:
print(
f'Attempt {idx} of {num_retries} to call {f.__name__} failed with "{e}"'
)
pass
return cast(T, rc)
return wrapper
return decorator

View File

@ -8,6 +8,7 @@ from gitutils import (
GitRepo,
patterns_to_regex,
PeekableIterator,
retries_decorator,
)
@ -49,6 +50,22 @@ class TestPattern(TestCase):
self.assertTrue(patterns_re.match(filename))
class TestRetriesDecorator(TestCase):
def test_simple(self) -> None:
@retries_decorator()
def foo(x: int, y: int) -> int:
return x + y
self.assertEqual(foo(3, 4), 7)
def test_fails(self) -> None:
@retries_decorator(rc=0)
def foo(x: int, y: int) -> int:
return x + y
self.assertEqual(foo("a", 4), 0)
class TestGitRepo(TestCase):
def setUp(self) -> None:
repo_dir = BASE_DIR.parent.parent.absolute()

View File

@ -38,6 +38,7 @@ from gitutils import (
get_git_repo_dir,
GitRepo,
patterns_to_regex,
retries_decorator,
)
from label_utils import (
gh_add_labels,
@ -1386,16 +1387,12 @@ def checks_to_markdown_bullets(
]
def _get_flaky_rules(url: str, num_retries: int = 3) -> List[FlakyRule]:
try:
return [FlakyRule(**rule) for rule in gh_fetch_json_list(url)]
except Exception as e:
print(f"Could not download {url} because: {e}.")
if num_retries > 0:
return _get_flaky_rules(url, num_retries=num_retries - 1)
return []
@retries_decorator(rc=[])
def _get_flaky_rules(url: str) -> List[FlakyRule]:
return [FlakyRule(**rule) for rule in gh_fetch_json_list(url)]
@retries_decorator()
def save_merge_record(
collection: str,
comment_id: int,
@ -1414,7 +1411,6 @@ def save_merge_record(
ignore_current: bool = False,
error: str = "",
workspace: str = "commons",
num_retries: int = 3,
) -> None:
"""
This saves the merge records into Rockset, so we can query them (for fun and profit)
@ -1460,35 +1456,9 @@ def save_merge_record(
print("Rockset is missing, no record will be saved")
return
except Exception as e:
if num_retries > 0:
print(f"Could not upload to Rockset ({num_retries - 1} tries left): {e}")
return save_merge_record(
collection=collection,
comment_id=comment_id,
pr_num=pr_num,
owner=owner,
project=project,
author=author,
pending_checks=pending_checks,
failed_checks=failed_checks,
last_commit_sha=last_commit_sha,
merge_base_sha=merge_base_sha,
merge_commit_sha=merge_commit_sha,
is_failed=is_failed,
dry_run=dry_run,
skip_mandatory_checks=skip_mandatory_checks,
ignore_current=ignore_current,
error=error,
workspace=workspace,
num_retries=num_retries - 1,
)
print(f"Could not upload to Rockset ({num_retries} tries left): {e}")
def get_rockset_results(
head_sha: str, merge_base: str, num_retries: int = 3
) -> List[Dict[str, Any]]:
@retries_decorator(rc=[])
def get_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
query = f"""
SELECT
w.name as workflow_name,
@ -1515,13 +1485,6 @@ where
except ModuleNotFoundError:
print("Could not use RockSet as rocket dependency is missing")
return []
except Exception as e:
print(f"Could not download rockset data because: {e}.")
if num_retries > 0:
return get_rockset_results(
head_sha, merge_base, num_retries=num_retries - 1
)
return []
def get_classifications(