#!/usr/bin/env python3
"""Unit tests for the helpers exported by the ``gitutils`` module."""

from collections.abc import Iterator
from pathlib import Path
from unittest import main, SkipTest, TestCase

from gitutils import (
    _shasum,
    are_ghstack_branches_in_sync,
    GitRepo,
    patterns_to_regex,
    PeekableIterator,
    retries_decorator,
)


# Directory containing this test file; used to locate the enclosing checkout.
BASE_DIR = Path(__file__).parent


class TestPeekableIterator(TestCase):
    """Checks iteration and look-ahead behaviour of ``PeekableIterator``."""

    def test_iterator(self, input_: str = "abcdef") -> None:
        """Iterating yields the characters of the input, in order."""
        # Comparing whole lists also catches an iterator that stops early,
        # which an element-by-element loop would silently accept.
        self.assertEqual(list(PeekableIterator(input_)), list(input_))

    def test_is_iterable(self) -> None:
        """``PeekableIterator`` is recognised as a ``collections.abc.Iterator``."""
        iter_ = PeekableIterator("")
        self.assertIsInstance(iter_, Iterator)

    def test_peek(self, input_: str = "abcdef") -> None:
        """``peek()`` returns the upcoming item, or ``None`` after the last one."""
        iter_ = PeekableIterator(input_)
        consumed = 0
        for idx, _ in enumerate(iter_):
            consumed += 1
            if idx + 1 < len(input_):
                self.assertEqual(iter_.peek(), input_[idx + 1])
            else:
                self.assertIsNone(iter_.peek())
        # Make sure the loop actually ran over the whole input, so the
        # end-of-input branch above cannot be skipped unnoticed.
        self.assertEqual(consumed, len(input_))


class TestPattern(TestCase):
    """Checks glob-to-regex conversion done by ``patterns_to_regex``."""

    def test_double_asterisks(self) -> None:
        """A ``**`` pattern matches files both directly in and below a directory."""
        allowed_patterns = [
            "aten/src/ATen/native/**LinearAlgebra*",
        ]
        patterns_re = patterns_to_regex(allowed_patterns)
        fnames = [
            "aten/src/ATen/native/LinearAlgebra.cpp",
            "aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp",
        ]
        for filename in fnames:
            # subTest reports every non-matching filename, not just the first.
            with self.subTest(filename=filename):
                self.assertTrue(patterns_re.match(filename))


class TestRetriesDecorator(TestCase):
    """Checks the success and failure paths of ``retries_decorator``."""

    def test_simple(self) -> None:
        """A decorated function that succeeds returns its normal result."""

        @retries_decorator()
        def foo(x: int, y: int) -> int:
            return x + y

        self.assertEqual(foo(3, 4), 7)

    def test_fails(self) -> None:
        """A decorated function that keeps failing yields the ``rc`` value."""

        @retries_decorator(rc=0)
        def foo(x: int, y: int) -> int:
            return x + y

        # "a" + 4 raises TypeError on every call; the test expects the
        # decorator to hand back rc (0) rather than propagate the error
        # (presumably after exhausting its retries -- see gitutils).
        self.assertEqual(foo("a", 4), 0)


class TestGitRepo(TestCase):
    """Tests that run ``GitRepo`` operations against the enclosing checkout."""

    def setUp(self) -> None:
        """Bind ``self.repo`` to the checkout two levels above this file.

        Raises:
            SkipTest: if that directory has no ``.git`` directory.
        """
        repo_dir = BASE_DIR.absolute().parent.parent
        if not (repo_dir / ".git").is_dir():
            raise SkipTest(
                "Can't find git directory, make sure to run this test on real repo checkout"
            )
        self.repo = GitRepo(str(repo_dir))

    def _skip_if_ref_does_not_exist(self, ref: str) -> None:
        """Skip test if ref is missing as stale branches are deleted with time"""
        try:
            self.repo.show_ref(ref)
        except RuntimeError as e:
            raise SkipTest(f"Can't find head ref {ref} due to \n{str(e)}") from e

    def test_compute_diff(self) -> None:
        """The checksum of the ``HEAD`` diff is 64 characters long."""
        diff = self.repo.diff("HEAD")
        sha = _shasum(diff)
        # 64 characters is the length of a SHA-256 hex digest
        # (presumably what _shasum produces -- confirm in gitutils).
        self.assertEqual(len(sha), 64)

    def test_ghstack_branches_in_sync(self) -> None:
        """A ghstack head ref expected to be in sync is reported as such."""
        head_ref = "gh/SS-JIA/206/head"
        self._skip_if_ref_does_not_exist(head_ref)
        self.assertTrue(are_ghstack_branches_in_sync(self.repo, head_ref))

    def test_ghstack_branches_not_in_sync(self) -> None:
        """A ghstack head ref expected to be out of sync is reported as such."""
        head_ref = "gh/clee2000/1/head"
        self._skip_if_ref_does_not_exist(head_ref)
        self.assertFalse(are_ghstack_branches_in_sync(self.repo, head_ref))


if __name__ == "__main__":
    main()