mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
filelock: Make waitcounter variant to use (#139816)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139816 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
6cb6e8d790
commit
d68403df3b
53
test/test_utils_filelock.py
Normal file
53
test/test_utils_filelock.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
import concurrent.futures
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfWindows, TestCase
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
|
||||
class TestFileLock(TestCase):
|
||||
def test_no_crash(self):
|
||||
_, p = tempfile.mkstemp()
|
||||
with FileLock(p):
|
||||
pass
|
||||
|
||||
@skipIfWindows(
|
||||
msg="Windows doesn't support multiple files being opened at once easily"
|
||||
)
|
||||
def test_sequencing(self):
|
||||
with tempfile.NamedTemporaryFile() as ofd:
|
||||
p = ofd.name
|
||||
|
||||
def test_thread(i):
|
||||
with FileLock(p + ".lock"):
|
||||
start = time.time()
|
||||
with open(p, "a") as fd:
|
||||
fd.write(str(i))
|
||||
end = time.time()
|
||||
return (start, end)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(test_thread, i) for i in range(10)]
|
||||
times = []
|
||||
for f in futures:
|
||||
times.append(f.result(60))
|
||||
|
||||
with open(p) as fd:
|
||||
self.assertEqual(
|
||||
set(fd.read()), {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}
|
||||
)
|
||||
|
||||
for i, (start, end) in enumerate(times):
|
||||
for j, (newstart, newend) in enumerate(times):
|
||||
if i == j:
|
||||
continue
|
||||
|
||||
# Times should never intersect
|
||||
self.assertFalse(newstart > start and newstart < end)
|
||||
self.assertFalse(newend > start and newstart < end)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -659,7 +659,7 @@ def put_local_code_state(cache_key: str) -> None:
|
||||
lock_path = path + ".lock"
|
||||
# We /mostly/ don't need the lock but the tmp file could be clobbered
|
||||
# TODO: use a safe tempfile create to eliminate lock
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
|
||||
|
@ -20,10 +20,9 @@ def aoti_eager_cache_dir(namespace: str, device: str) -> Path:
|
||||
|
||||
|
||||
def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any:
|
||||
from filelock import FileLock
|
||||
|
||||
# Avoid circular import
|
||||
from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
op_conf_lock_file = f"{op_func_name_with_overload}.lock"
|
||||
lock_dir = get_lock_dir()
|
||||
|
@ -1568,7 +1568,7 @@ class AotCodeCompiler:
|
||||
pos += rc
|
||||
return consts_o
|
||||
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
||||
@ -2003,7 +2003,7 @@ class CppCodeCache:
|
||||
key, input_path = write(source_code, "cpp", extra=vec_isa_cmd)
|
||||
|
||||
if key not in cls.cache:
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
lock_path = os.path.join(get_lock_dir(), key + ".lock")
|
||||
output_name, output_dir = get_name_and_dir_from_output_file_path(input_path)
|
||||
@ -2068,7 +2068,7 @@ def _worker_compile_cpp(
|
||||
fb_input_path: str,
|
||||
fb_output_path: str,
|
||||
) -> None:
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
with FileLock(lock_path, timeout=LOCK_TIMEOUT):
|
||||
binary_path = (
|
||||
@ -2646,10 +2646,11 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||
afile = str(dirpath / "standalone_halide_runtime.a")
|
||||
sofile = str(dirpath / libname)
|
||||
if not os.path.exists(donefile):
|
||||
import filelock
|
||||
import halide as hl # type: ignore[import-untyped,import-not-found]
|
||||
|
||||
with filelock.FileLock(lockfile, LOCK_TIMEOUT):
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
with FileLock(lockfile, LOCK_TIMEOUT):
|
||||
if not os.path.exists(donefile):
|
||||
with open(hookfile, "w") as f:
|
||||
if device_type == "cuda":
|
||||
@ -2680,7 +2681,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||
|
||||
|
||||
def _worker_task_halide(lockfile: str, jobs: List[partial[Any]]) -> None:
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
try:
|
||||
with FileLock(lockfile, LOCK_TIMEOUT):
|
||||
@ -3075,7 +3076,7 @@ class CUDACodeCache:
|
||||
"""
|
||||
key, input_path = cls.write(source_code, dst_file_ext)
|
||||
if key not in cls.cache:
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
||||
@ -3166,7 +3167,7 @@ class ROCmCodeCache:
|
||||
|
||||
key, input_path = cls.write(source_code, dst_file_ext)
|
||||
if key not in cls.cache:
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
||||
|
@ -79,7 +79,7 @@ def cpp_compiler_search(search: str) -> str:
|
||||
# Do not install GXX by default
|
||||
if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"):
|
||||
continue
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(
|
||||
|
@ -101,7 +101,7 @@ cdll.LoadLibrary("__lib_path__")
|
||||
"cpp",
|
||||
extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
|
||||
)
|
||||
from filelock import FileLock
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
||||
|
@ -20,12 +20,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Un
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
from filelock import FileLock
|
||||
|
||||
import torch
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
from torch._dynamo.testing import rand_strided
|
||||
from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
from . import config, ir
|
||||
from .autotune_process import (
|
||||
|
42
torch/utils/_filelock.py
Normal file
42
torch/utils/_filelock.py
Normal file
@ -0,0 +1,42 @@
|
||||
from types import TracebackType
|
||||
from typing import Optional
|
||||
from typing_extensions import Self
|
||||
|
||||
from filelock import FileLock as base_FileLock
|
||||
|
||||
from torch.monitor import _WaitCounter
|
||||
|
||||
|
||||
class FileLock(base_FileLock):
|
||||
"""
|
||||
This behaves like a normal file lock.
|
||||
|
||||
However, it adds waitcounters for acquiring and releasing the filelock
|
||||
as well as for the critical region within it.
|
||||
|
||||
pytorch.filelock.enter - While we're acquiring the filelock.
|
||||
pytorch.filelock.region - While we're holding the filelock and doing work.
|
||||
pytorch.filelock.exit - While we're releasing the filelock.
|
||||
"""
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self.region_counter = _WaitCounter("pytorch.filelock.region").guard()
|
||||
with _WaitCounter("pytorch.filelock.enter").guard():
|
||||
result = super().__enter__()
|
||||
self.region_counter.__enter__()
|
||||
return result
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.region_counter.__exit__()
|
||||
with _WaitCounter("pytorch.filelock.exit").guard():
|
||||
# Returns nothing per
|
||||
# https://github.com/tox-dev/filelock/blob/57f488ff8fdc2193572efe102408fb63cfefe4e4/src/filelock/_api.py#L379
|
||||
super().__exit__(exc_type, exc_value, traceback)
|
||||
# Returns nothing per
|
||||
# https://github.com/pytorch/pytorch/blob/0f6bfc58a2cfb7a5c052bea618ab62becaf5c912/torch/csrc/monitor/python_init.cpp#L315
|
||||
return None
|
Reference in New Issue
Block a user