mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 2293fe1024812d6349f6e2b3b7de82c6b73f11e4. Reverted https://github.com/pytorch/pytorch/pull/129374 on behalf of https://github.com/malfet due to failing internal ROCM builds with error: ModuleNotFoundError: No module named hipify ([comment](https://github.com/pytorch/pytorch/pull/129374#issuecomment-2562973920))
103 lines
3.0 KiB
Python
#!/usr/bin/env python3
|
|
from pathlib import Path
|
|
from unittest import main, SkipTest, TestCase
|
|
|
|
from gitutils import (
|
|
_shasum,
|
|
are_ghstack_branches_in_sync,
|
|
GitRepo,
|
|
patterns_to_regex,
|
|
PeekableIterator,
|
|
retries_decorator,
|
|
)
|
|
|
|
|
|
# Directory that holds this test module; TestGitRepo locates the repo root relative to it.
BASE_DIR = Path(__file__).parents[0]
|
|
|
|
|
|
class TestPeekableIterator(TestCase):
    """Tests for ``PeekableIterator``: iteration order, Iterator ABC, and peek()."""

    def test_iterator(self, input_: str = "abcdef") -> None:
        """Iterating yields every element of ``input_``, in order.

        The output is collected and compared as a whole so that an iterator
        yielding too few (or zero) items fails instead of passing vacuously.
        """
        iter_ = PeekableIterator(input_)
        self.assertEqual(list(iter_), list(input_))

    def test_is_iterable(self) -> None:
        """A PeekableIterator is recognised as a ``collections.abc.Iterator``."""
        from collections.abc import Iterator

        iter_ = PeekableIterator("")
        self.assertIsInstance(iter_, Iterator)

    def test_peek(self, input_: str = "abcdef") -> None:
        """peek() previews the upcoming element and returns None once exhausted.

        Also checks that peeking does not consume: each step of the loop must
        still yield the element at its own index.
        """
        iter_ = PeekableIterator(input_)
        seen = 0
        for idx, c in enumerate(iter_):
            # If peek() advanced the iterator, this comparison would drift off by one.
            self.assertEqual(c, input_[idx])
            if idx + 1 < len(input_):
                self.assertEqual(iter_.peek(), input_[idx + 1])
            else:
                self.assertIsNone(iter_.peek())
            seen += 1
        # Guard against a vacuous pass when the iterator yields nothing at all.
        self.assertEqual(seen, len(input_))
|
|
|
|
|
|
class TestPattern(TestCase):
    """Tests for ``patterns_to_regex``."""

    def test_double_asterisks(self) -> None:
        """A ``**`` glob matches files directly in the directory and in nested subdirectories."""
        allowed_patterns = [
            "aten/src/ATen/native/**LinearAlgebra*",
        ]
        patterns_re = patterns_to_regex(allowed_patterns)
        fnames = [
            "aten/src/ATen/native/LinearAlgebra.cpp",
            "aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp",
        ]
        for filename in fnames:
            # subTest reports every non-matching filename instead of stopping at
            # the first one, and the message names the offending path.
            with self.subTest(filename=filename):
                self.assertIsNotNone(
                    patterns_re.match(filename),
                    f"{filename!r} did not match {allowed_patterns!r}",
                )
|
|
|
|
|
|
class TestRetriesDecorator(TestCase):
    """Tests for ``retries_decorator``: pass-through on success, ``rc`` on failure."""

    def test_simple(self) -> None:
        """A call that succeeds returns the wrapped function's own result."""

        @retries_decorator()
        def foo(x: int, y: int) -> int:
            return x + y

        result = foo(3, 4)
        self.assertEqual(result, 7)

    def test_fails(self) -> None:
        """A call that raises is expected to yield the decorator's ``rc`` value."""

        @retries_decorator(rc=0)
        def foo(x: int, y: int) -> int:
            return x + y

        # "a" + 4 raises TypeError inside foo; the test expects the decorator
        # to absorb it and return rc instead.
        fallback = foo("a", 4)  # type: ignore[arg-type]
        self.assertEqual(fallback, 0)
|
|
|
|
|
|
class TestGitRepo(TestCase):
    """Tests for ``GitRepo`` and the ghstack helpers, run against the real checkout.

    Every test is skipped when no git checkout is found at the expected root.
    """

    def setUp(self) -> None:
        # Assumed repo root: two directory levels above this file's directory.
        repo_dir = BASE_DIR.parent.parent.absolute()
        # ``.git`` is a directory in a regular clone. In linked worktrees and
        # submodules it is a plain "gitfile" pointing elsewhere. Checking
        # exists() rather than is_dir() keeps these tests from being skipped there.
        if not (repo_dir / ".git").exists():
            raise SkipTest(
                "Can't find git directory, make sure to run this test on real repo checkout"
            )
        self.repo = GitRepo(str(repo_dir))

    def _skip_if_ref_does_not_exist(self, ref: str) -> None:
        """Skip test if ref is missing as stale branches are deleted with time.

        Raises:
            SkipTest: when ``self.repo.show_ref(ref)`` raises RuntimeError.
        """
        try:
            self.repo.show_ref(ref)
        except RuntimeError as e:
            raise SkipTest(f"Can't find head ref {ref} due to {e}") from e

    def test_compute_diff(self) -> None:
        """Hashing the diff against HEAD gives a 64-character digest (presumably SHA-256 hex)."""
        diff = self.repo.diff("HEAD")
        sha = _shasum(diff)
        self.assertEqual(len(sha), 64)

    def test_ghstack_branches_in_sync(self) -> None:
        """A known ghstack head ref is reported as in sync (skipped if the ref is gone)."""
        head_ref = "gh/SS-JIA/206/head"
        self._skip_if_ref_does_not_exist(head_ref)
        self.assertTrue(are_ghstack_branches_in_sync(self.repo, head_ref))

    def test_ghstack_branches_not_in_sync(self) -> None:
        """A known ghstack head ref is reported as out of sync (skipped if the ref is gone)."""
        head_ref = "gh/clee2000/1/head"
        self._skip_if_ref_does_not_exist(head_ref)
        self.assertFalse(are_ghstack_branches_in_sync(self.repo, head_ref))
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand control to unittest's command-line runner for the tests in this module.
    main(module="__main__")
|