# Owner(s): ["module: unknown"] import concurrent.futures import tempfile import time from torch.testing._internal.common_utils import run_tests, skipIfWindows, TestCase from torch.utils._filelock import FileLock class TestFileLock(TestCase): def test_no_crash(self): _, p = tempfile.mkstemp() with FileLock(p): pass @skipIfWindows( msg="Windows doesn't support multiple files being opened at once easily" ) def test_sequencing(self): with tempfile.NamedTemporaryFile() as ofd: p = ofd.name def test_thread(i): with FileLock(p + ".lock"): start = time.time() with open(p, "a") as fd: fd.write(str(i)) end = time.time() return (start, end) with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [executor.submit(test_thread, i) for i in range(10)] times = [] for f in futures: times.append(f.result(60)) with open(p) as fd: self.assertEqual( set(fd.read()), {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} ) for i, (start, end) in enumerate(times): for j, (newstart, newend) in enumerate(times): if i == j: continue # Times should never intersect self.assertFalse(newstart > start and newstart < end) self.assertFalse(newend > start and newstart < end) if __name__ == "__main__": run_tests()