mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139816 Approved by: https://github.com/ezyang
54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
# Owner(s): ["module: unknown"]
|
|
import concurrent.futures
|
|
import tempfile
|
|
import time
|
|
|
|
from torch.testing._internal.common_utils import run_tests, skipIfWindows, TestCase
|
|
from torch.utils._filelock import FileLock
|
|
|
|
|
|
class TestFileLock(TestCase):
|
|
def test_no_crash(self):
|
|
_, p = tempfile.mkstemp()
|
|
with FileLock(p):
|
|
pass
|
|
|
|
@skipIfWindows(
|
|
msg="Windows doesn't support multiple files being opened at once easily"
|
|
)
|
|
def test_sequencing(self):
|
|
with tempfile.NamedTemporaryFile() as ofd:
|
|
p = ofd.name
|
|
|
|
def test_thread(i):
|
|
with FileLock(p + ".lock"):
|
|
start = time.time()
|
|
with open(p, "a") as fd:
|
|
fd.write(str(i))
|
|
end = time.time()
|
|
return (start, end)
|
|
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
|
futures = [executor.submit(test_thread, i) for i in range(10)]
|
|
times = []
|
|
for f in futures:
|
|
times.append(f.result(60))
|
|
|
|
with open(p) as fd:
|
|
self.assertEqual(
|
|
set(fd.read()), {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}
|
|
)
|
|
|
|
for i, (start, end) in enumerate(times):
|
|
for j, (newstart, newend) in enumerate(times):
|
|
if i == j:
|
|
continue
|
|
|
|
# Times should never intersect
|
|
self.assertFalse(newstart > start and newstart < end)
|
|
self.assertFalse(newend > start and newstart < end)
|
|
|
|
|
|
# When this file is executed directly, hand control to the run_tests helper
# imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
|